From 194a2ac67bb956e05e5a076ecf6893a73e569589 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sun, 2 May 2021 20:35:04 +0200 Subject: Implement small counting network --- src/counting_small.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/counting_small.py (limited to 'src/counting_small.py') diff --git a/src/counting_small.py b/src/counting_small.py new file mode 100644 index 0000000..48d967c --- /dev/null +++ b/src/counting_small.py @@ -0,0 +1,30 @@ +import torch + +def target_transform(labels): + count_labels = torch.zeros(60) + for shape_index in range(len(labels)): + count_labels[int(shape_index * 10 + labels[shape_index])] = 1 + return count_labels + +def loss_function(output, target): + a = torch.tensor(range(10)).repeat(6 * target.shape[0]).reshape(target.shape) + errors = a - target + + return torch.sum(output * errors ** 2) / output.shape[0] + +def count_correct(output, target): + print('output shape', output.shape) + print('target shape', target.shape) + output = output.reshape(output.shape[0], 6, 10) + target = target.reshape(target.shape[0], 6, 10) + target = torch.argmax(target, dim=2) + predictions = torch.argmax(output, dim=2) + correct = torch.min(predictions.reshape(target.shape) == target, dim=1).values + return torch.sum(correct).tolist() + +def finalizer(x): + x = x.reshape(x.shape[0], 6, 10) + x = torch.softmax(x, 1) + return x.reshape(x.shape[0], 60) + +outputs = 60 -- cgit v1.2.3