import torch
import torch.nn.functional as F

# Number of leading label entries scanned for shapes (indices 0..5).
NUM_SHAPES = 6
# Number of count values encoded per shape pair (first-shape count 1..9).
MAX_COUNT = 9
# Number of unordered pairs of distinct shapes: C(6, 2) = 15.
NUM_PAIRS = NUM_SHAPES * (NUM_SHAPES - 1) // 2

# Size of the one-hot target / model output: 15 pairs x 9 counts = 135.
outputs = NUM_PAIRS * MAX_COUNT


def target_transform(labels):
    """Encode a per-shape count vector as a one-hot class vector.

    The first two indices in ``range(6)`` whose entry in ``labels`` is > 0
    form an ordered pair (first_shape < second_shape). Any further positive
    entries are ignored. Each pair owns a block of ``MAX_COUNT`` classes.
    Within that block, the class is chosen by the count of the first shape
    (1..MAX_COUNT).

    Args:
        labels: Indexable sequence (e.g. a list or 1-D tensor) with at least
            ``NUM_SHAPES`` entries.

    Returns:
        A float tensor of length ``outputs`` with a single 1 at the class
        index and 0 elsewhere.

    Raises:
        ValueError: If fewer than two of the first ``NUM_SHAPES`` entries are
            positive, or if the first shape's count is not in 1..MAX_COUNT.
    """
    present = [i for i in range(NUM_SHAPES) if labels[i] > 0]
    if len(present) < 2:
        raise ValueError(
            f"expected at least two shapes with a positive count, "
            f"found {len(present)}"
        )
    first_shape, second_shape = present[0], present[1]

    first_shape_count = int(labels[first_shape])
    # Without this check, a count above MAX_COUNT would silently spill into
    # the next pair's class block and produce a wrong label.
    if not 1 <= first_shape_count <= MAX_COUNT:
        raise ValueError(
            f"first shape count must be in 1..{MAX_COUNT}, "
            f"got {first_shape_count}"
        )

    # Index of the first pair whose smaller element is `first_shape`.
    # Pairs are enumerated lexicographically:
    # (0,1), (0,2), ..., (0,5), (1,2), ..., (4,5).
    # The product below is always even, so integer division is exact.
    first_offset = NUM_PAIRS - (
        (NUM_SHAPES - first_shape) * (NUM_SHAPES - 1 - first_shape) // 2
    )
    pair_index = first_offset + (second_shape - first_shape) - 1
    offset = pair_index * MAX_COUNT + (first_shape_count - 1)

    new_labels = torch.zeros(outputs)
    new_labels[offset] = 1
    return new_labels


def loss_function(output, target):
    """Return the mean binary cross-entropy between ``output`` and ``target``.

    ``F.binary_cross_entropy`` requires ``output`` to already hold
    probabilities in [0, 1]. Presumably the caller applies ``finalizer``
    first (TODO confirm against the training loop).
    """
    return F.binary_cross_entropy(output, target)


def count_correct(output, target):
    """Count rows where the argmax of ``output`` matches that of ``target``.

    Both tensors are reduced with ``argmax`` along dim 1, so they are
    expected to be at least 2-D (rows along dim 0, classes along dim 1).

    Returns:
        The number of matching rows as a Python int.
    """
    predicted = torch.argmax(output, dim=1)
    expected = torch.argmax(target, dim=1)
    return torch.sum(predicted == expected).item()


def finalizer(x):
    """Turn raw scores into probabilities with a softmax over dim 1."""
    return torch.softmax(x, 1)