diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-02 20:35:04 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-02 20:35:04 +0200 |
commit | 194a2ac67bb956e05e5a076ecf6893a73e569589 (patch) | |
tree | e2829940ee771176aa7a73fd5c6b0612f03ae442 /src | |
parent | 45b67cbadd51d81bd55e759a07f1e1a5fbf5d5c4 (diff) |
Implement small counting network
Diffstat (limited to 'src')
-rw-r--r-- | src/counting_small.py | 30 | ||||
-rw-r--r-- | src/net_types.py | 4 |
2 files changed, 33 insertions, 1 deletions
diff --git a/src/counting_small.py b/src/counting_small.py new file mode 100644 index 0000000..48d967c --- /dev/null +++ b/src/counting_small.py @@ -0,0 +1,30 @@ +import torch + +def target_transform(labels): + count_labels = torch.zeros(60) + for shape_index in range(len(labels)): + count_labels[int(shape_index * 10 + labels[shape_index])] = 1 + return count_labels + +def loss_function(output, target): + a = torch.tensor(range(10)).repeat(6 * target.shape[0]).reshape(target.shape) + errors = a - target + + return torch.sum(output * errors ** 2) / output.shape[0] + +def count_correct(output, target): + print('output shape', output.shape) + print('target shape', target.shape) + output = output.reshape(output.shape[0], 6, 10) + target = target.reshape(target.shape[0], 6, 10) + target = torch.argmax(target, dim=2) + predictions = torch.argmax(output, dim=2) + correct = torch.min(predictions.reshape(target.shape) == target, dim=1).values + return torch.sum(correct).tolist() + +def finalizer(x): + x = x.reshape(x.shape[0], 6, 10) + x = torch.softmax(x, 1) + return x.reshape(x.shape[0], 60) + +outputs = 60 diff --git a/src/net_types.py b/src/net_types.py index bb52d50..14925c8 100644 --- a/src/net_types.py +++ b/src/net_types.py @@ -1,5 +1,7 @@ import classification +import counting_small types = { - 'classification': classification + 'classification': classification, + 'counting-small': counting_small } |