diff options
Diffstat (limited to 'src/counting_big.py')
-rw-r--r-- | src/counting_big.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/src/counting_big.py b/src/counting_big.py new file mode 100644 index 0000000..a5fb80c --- /dev/null +++ b/src/counting_big.py @@ -0,0 +1,36 @@ +import torch.nn.functional as F +import torch + +def target_transform(labels): + first_shape_count = None + first_shape = None + second_shape = None + for i in range(6): + if labels[i] > 0: + if first_shape is None: + first_shape = i + first_shape_count = labels[i] + else: + second_shape = i + break + + first_offset = 15 - (6 - first_shape) * (5 - first_shape) / 2 + pair_offset = first_offset + (second_shape - first_shape) - 1 + offset = int(pair_offset * 9 + (first_shape_count - 1)) + new_labels = torch.zeros(135) + new_labels[offset] = 1 + return new_labels + + +def loss_function(output, target): + return F.binary_cross_entropy(output, target) + +def count_correct(output, target): + output = torch.argmax(output, dim=1) + target = torch.argmax(target, dim=1) + return torch.sum(output == target).tolist() + +def finalizer(x): + return torch.softmax(x, 1) + +outputs = 135 |