m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-02 20:35:04 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-02 20:35:04 +0200
commit194a2ac67bb956e05e5a076ecf6893a73e569589 (patch)
treee2829940ee771176aa7a73fd5c6b0612f03ae442
parent45b67cbadd51d81bd55e759a07f1e1a5fbf5d5c4 (diff)
Implement small counting network
-rw-r--r--src/counting_small.py30
-rw-r--r--src/net_types.py4
2 files changed, 33 insertions, 1 deletions
diff --git a/src/counting_small.py b/src/counting_small.py
new file mode 100644
index 0000000..48d967c
--- /dev/null
+++ b/src/counting_small.py
@@ -0,0 +1,30 @@
+import torch
+
+def target_transform(labels):
+ count_labels = torch.zeros(60)
+ for shape_index in range(len(labels)):
+ count_labels[int(shape_index * 10 + labels[shape_index])] = 1
+ return count_labels
+
+def loss_function(output, target):
+ a = torch.tensor(range(10)).repeat(6 * target.shape[0]).reshape(target.shape)
+ errors = a - target
+
+ return torch.sum(output * errors ** 2) / output.shape[0]
+
+def count_correct(output, target):
+ print('output shape', output.shape)
+ print('target shape', target.shape)
+ output = output.reshape(output.shape[0], 6, 10)
+ target = target.reshape(target.shape[0], 6, 10)
+ target = torch.argmax(target, dim=2)
+ predictions = torch.argmax(output, dim=2)
+ correct = torch.min(predictions.reshape(target.shape) == target, dim=1).values
+ return torch.sum(correct).tolist()
+
+def finalizer(x):
+ x = x.reshape(x.shape[0], 6, 10)
+ x = torch.softmax(x, 1)
+ return x.reshape(x.shape[0], 60)
+
+outputs = 60
diff --git a/src/net_types.py b/src/net_types.py
index bb52d50..14925c8 100644
--- a/src/net_types.py
+++ b/src/net_types.py
@@ -1,5 +1,7 @@
import classification
+import counting_small
types = {
- 'classification': classification
+ 'classification': classification,
+ 'counting-small': counting_small
}