blob: 747c1c704f41cf7eacea930a9ee129453d81a34f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
import torch.nn.functional as F
import torch
def target_transform(labels, num_shapes=6, max_count=9):
    """Encode a per-shape count vector as a one-hot over (shape pair, count).

    The first two indices in ``0..num_shapes-1`` whose entry in ``labels`` is
    positive form a pair ``(first, second)`` with ``first < second``. Each of
    the ``num_shapes * (num_shapes - 1) // 2`` pairs owns ``max_count``
    consecutive slots. The slot inside a pair's group is selected by the count
    of the *first* shape. Positive entries after the second one are ignored.

    With the defaults (6 shapes, counts 1..9) the result has 15 * 9 = 135
    entries.

    Args:
        labels: Indexable sequence (e.g. a list or 1-D tensor) holding at
            least ``num_shapes`` per-shape counts.
        num_shapes: Number of leading entries of ``labels`` to scan.
        max_count: Largest supported count for the first shape.

    Returns:
        A float ``torch`` tensor of length ``num_pairs * max_count`` containing
        a single 1 at the encoded position.

    Raises:
        ValueError: if fewer than two shapes have a positive count, or if the
            first shape's count (truncated to int) is outside
            ``1..max_count``. Such counts would otherwise alias into a
            neighbouring pair's slots or index out of range.
    """
    first_shape = None
    first_shape_count = None
    second_shape = None
    for i in range(num_shapes):
        if labels[i] > 0:
            if first_shape is None:
                first_shape = i
                first_shape_count = labels[i]
            else:
                second_shape = i
                break

    if second_shape is None:
        raise ValueError(
            "target_transform needs at least two shapes with a positive count"
        )
    count = int(first_shape_count)
    if not 1 <= count <= max_count:
        raise ValueError(
            f"first shape count must be in 1..{max_count}, got {count}"
        )

    num_pairs = num_shapes * (num_shapes - 1) // 2
    # Number of pairs (a, b) with a < first_shape, i.e. the pairs laid out
    # before this shape's group: C(num_shapes, 2) - C(num_shapes - first, 2).
    remaining = num_shapes - first_shape
    first_offset = num_pairs - remaining * (remaining - 1) // 2
    pair_offset = first_offset + (second_shape - first_shape - 1)
    offset = pair_offset * max_count + (count - 1)

    new_labels = torch.zeros(num_pairs * max_count)
    new_labels[offset] = 1
    return new_labels
def loss_function(output, target):
    """Return the binary cross-entropy between ``output`` and ``target``.

    ``F.binary_cross_entropy`` expects ``output`` to already contain
    probabilities in [0, 1], not logits. Presumably the caller passes the
    result of ``finalizer`` here — confirm against the training loop.
    PyTorch's default ``reduction='mean'`` averages the loss over all elements.
    """
    bce = F.binary_cross_entropy(output, target)
    return bce
def count_correct(output, target):
    """Count positions where the arg-max along dim 1 agrees between tensors.

    Dimension 1 is treated as the class axis, presumably ``(batch, classes)``
    — confirm with callers. Returns a plain Python ``int``.
    """
    predicted_class = output.argmax(dim=1)
    true_class = target.argmax(dim=1)
    matches = predicted_class == true_class
    return matches.sum().item()
def finalizer(x):
    """Map raw scores to probabilities via a softmax over dimension 1."""
    return x.softmax(dim=1)
# Length of the one-hot vector built by target_transform: (6 choose 2) = 15
# shape pairs times 9 count slots per pair. Presumably also the model's output
# width — confirm with the code that consumes this constant.
outputs = 6 * 5 // 2 * 9
|