diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/counting_big.py | 36 | ||||
-rw-r--r-- | src/net_types.py | 4 |
2 files changed, 39 insertions, 1 deletions
diff --git a/src/counting_big.py b/src/counting_big.py new file mode 100644 index 0000000..a5fb80c --- /dev/null +++ b/src/counting_big.py @@ -0,0 +1,36 @@ +import torch.nn.functional as F +import torch + +def target_transform(labels): + first_shape_count = None + first_shape = None + second_shape = None + for i in range(6): + if labels[i] > 0: + if first_shape is None: + first_shape = i + first_shape_count = labels[i] + else: + second_shape = i + break + + first_offset = 15 - (6 - first_shape) * (5 - first_shape) / 2 + pair_offset = first_offset + (second_shape - first_shape) - 1 + offset = int(pair_offset * 9 + (first_shape_count - 1)) + new_labels = torch.zeros(135) + new_labels[offset] = 1 + return new_labels + + +def loss_function(output, target): + return F.binary_cross_entropy(output, target) + +def count_correct(output, target): + output = torch.argmax(output, dim=1) + target = torch.argmax(target, dim=1) + return torch.sum(output == target).tolist() + +def finalizer(x): + return torch.softmax(x, 1) + +outputs = 135 diff --git a/src/net_types.py b/src/net_types.py index 14925c8..1f308a4 100644 --- a/src/net_types.py +++ b/src/net_types.py @@ -1,7 +1,9 @@ import classification import counting_small +import counting_big types = { 'classification': classification, - 'counting-small': counting_small + 'counting-small': counting_small, + 'counting-big': counting_big, } |