diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/counting_big.py | 36 | ||||
| -rw-r--r-- | src/net_types.py | 4 | 
2 files changed, 39 insertions, 1 deletions
| diff --git a/src/counting_big.py b/src/counting_big.py new file mode 100644 index 0000000..a5fb80c --- /dev/null +++ b/src/counting_big.py @@ -0,0 +1,36 @@ +import torch.nn.functional as F +import torch + +def target_transform(labels): +    first_shape_count = None +    first_shape = None +    second_shape = None +    for i in range(6): +        if labels[i] > 0: +            if first_shape is None: +                first_shape = i +                first_shape_count = labels[i] +            else: +                second_shape = i +                break + +    first_offset = 15 - (6 - first_shape) * (5 - first_shape) / 2 +    pair_offset = first_offset + (second_shape - first_shape) - 1 +    offset = int(pair_offset * 9 + (first_shape_count - 1)) +    new_labels = torch.zeros(135) +    new_labels[offset] = 1 +    return new_labels + + +def loss_function(output, target): +    return F.binary_cross_entropy(output, target) + +def count_correct(output, target): +    output = torch.argmax(output, dim=1) +    target = torch.argmax(target, dim=1) +    return torch.sum(output == target).tolist() + +def finalizer(x): +    return torch.softmax(x, 1) + +outputs = 135 diff --git a/src/net_types.py b/src/net_types.py index 14925c8..1f308a4 100644 --- a/src/net_types.py +++ b/src/net_types.py @@ -1,7 +1,9 @@  import classification  import counting_small +import counting_big  types = {      'classification': classification, -    'counting-small': counting_small +    'counting-small': counting_small, +    'counting-big': counting_big,  } |