From 194a2ac67bb956e05e5a076ecf6893a73e569589 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sun, 2 May 2021 20:35:04 +0200 Subject: Implement small counting network --- src/counting_small.py | 30 ++++++++++++++++++++++++++++++ src/net_types.py | 4 +++- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 src/counting_small.py diff --git a/src/counting_small.py b/src/counting_small.py new file mode 100644 index 0000000..48d967c --- /dev/null +++ b/src/counting_small.py @@ -0,0 +1,30 @@ +import torch + +def target_transform(labels): + count_labels = torch.zeros(60) + for shape_index in range(len(labels)): + count_labels[int(shape_index * 10 + labels[shape_index])] = 1 + return count_labels + +def loss_function(output, target): + a = torch.tensor(range(10)).repeat(6 * target.shape[0]).reshape(target.shape) + errors = a - target + + return torch.sum(output * errors ** 2) / output.shape[0] + +def count_correct(output, target): + print('output shape', output.shape) + print('target shape', target.shape) + output = output.reshape(output.shape[0], 6, 10) + target = target.reshape(target.shape[0], 6, 10) + target = torch.argmax(target, dim=2) + predictions = torch.argmax(output, dim=2) + correct = torch.min(predictions.reshape(target.shape) == target, dim=1).values + return torch.sum(correct).tolist() + +def finalizer(x): + x = x.reshape(x.shape[0], 6, 10) + x = torch.softmax(x, 1) + return x.reshape(x.shape[0], 60) + +outputs = 60 diff --git a/src/net_types.py b/src/net_types.py index bb52d50..14925c8 100644 --- a/src/net_types.py +++ b/src/net_types.py @@ -1,5 +1,7 @@ import classification +import counting_small types = { - 'classification': classification + 'classification': classification, + 'counting-small': counting_small } -- cgit v1.2.3