m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/counting_big.py36
-rw-r--r--src/net_types.py4
2 files changed, 39 insertions, 1 deletions
diff --git a/src/counting_big.py b/src/counting_big.py
new file mode 100644
index 0000000..a5fb80c
--- /dev/null
+++ b/src/counting_big.py
@@ -0,0 +1,36 @@
+import torch.nn.functional as F
+import torch
+
+def target_transform(labels):
+ first_shape_count = None
+ first_shape = None
+ second_shape = None
+ for i in range(6):
+ if labels[i] > 0:
+ if first_shape is None:
+ first_shape = i
+ first_shape_count = labels[i]
+ else:
+ second_shape = i
+ break
+
+ first_offset = 15 - (6 - first_shape) * (5 - first_shape) / 2
+ pair_offset = first_offset + (second_shape - first_shape) - 1
+ offset = int(pair_offset * 9 + (first_shape_count - 1))
+ new_labels = torch.zeros(135)
+ new_labels[offset] = 1
+ return new_labels
+
+
+def loss_function(output, target):
+ return F.binary_cross_entropy(output, target)
+
+def count_correct(output, target):
+ output = torch.argmax(output, dim=1)
+ target = torch.argmax(target, dim=1)
+ return torch.sum(output == target).tolist()
+
+def finalizer(x):
+ return torch.softmax(x, 1)
+
+outputs = 135
diff --git a/src/net_types.py b/src/net_types.py
index 14925c8..1f308a4 100644
--- a/src/net_types.py
+++ b/src/net_types.py
@@ -1,7 +1,9 @@
import classification
import counting_small
+import counting_big
types = {
'classification': classification,
- 'counting-small': counting_small
+ 'counting-small': counting_small,
+ 'counting-big': counting_big,
}