m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/counting_big.py
blob: a5fb80c9662cf984353a2062debfc7543e89a68f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch.nn.functional as F
import torch

def target_transform(labels):
    """Encode a per-shape count vector as a 135-way one-hot target.

    ``labels[i]`` is read as the count of shape ``i`` for ``i`` in ``range(6)``
    (assumed to be non-negative integer counts — TODO confirm against the
    dataset). The first two shapes with a positive count form an ordered pair
    ``(first_shape, second_shape)`` with ``first_shape < second_shape``. There
    are 6 * 5 / 2 = 15 such pairs, and each pair gets 9 slots, one per count
    of ``first_shape`` (1..9). The second shape's count is not encoded;
    presumably it is implied by a fixed total — TODO confirm. Any shapes
    present beyond the second are ignored.

    Args:
        labels: Indexable of at least 6 counts (list or 1-D tensor).

    Returns:
        A float tensor of length 135 with a single 1 at the encoded index.

    Raises:
        ValueError: If fewer than two shapes have a positive count, or the
            first shape's count is outside 1..9 (it would otherwise alias
            into another pair's slots or fall off the end of the vector).
    """
    num_shapes = 6
    max_count = 9
    num_pairs = num_shapes * (num_shapes - 1) // 2  # 15

    present = [i for i in range(num_shapes) if labels[i] > 0]
    if len(present) < 2:
        raise ValueError(
            f"expected at least two shapes with a positive count, got {present}"
        )
    first_shape, second_shape = present[0], present[1]

    first_shape_count = int(labels[first_shape])
    if not 1 <= first_shape_count <= max_count:
        raise ValueError(
            f"count of first shape must be in 1..{max_count}, "
            f"got {first_shape_count}"
        )

    # Pairs (a, b) with a < b are enumerated lexicographically. There are
    # C(num_shapes - a, 2) pairs whose first element is >= a, so the first
    # pair starting with `first_shape` sits at num_pairs minus that amount.
    remaining = num_shapes - first_shape
    first_offset = num_pairs - remaining * (remaining - 1) // 2
    pair_offset = first_offset + (second_shape - first_shape - 1)
    offset = pair_offset * max_count + (first_shape_count - 1)

    new_labels = torch.zeros(num_pairs * max_count)
    new_labels[offset] = 1
    return new_labels


def loss_function(output, target):
    """Return the mean binary cross-entropy between ``output`` and ``target``.

    ``output`` must already hold probabilities in [0, 1] (presumably the
    result of ``finalizer``). ``target`` has the same shape (presumably the
    one-hot vectors from ``target_transform``).
    """
    bce = F.binary_cross_entropy(input=output, target=target)
    return bce

def count_correct(output, target):
    """Count rows whose arg-max in ``output`` equals the arg-max in ``target``.

    Both tensors are reduced along dimension 1 (presumably the class axis of
    a (batch, classes) tensor — TODO confirm). Returns a plain Python int.
    """
    predicted = output.argmax(dim=1)
    expected = target.argmax(dim=1)
    matches = predicted.eq(expected)
    return matches.sum().tolist()

def finalizer(x):
    """Convert raw scores into probabilities via a softmax over dimension 1."""
    return x.softmax(dim=1)

# Number of output classes. This matches the length of the one-hot vector
# built by target_transform: 15 ordered shape pairs x 9 count slots each.
outputs = 15 * 9