m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/classification.py
blob: 7a93e7fe079b8f362d8f0a0cc46158286c9a667d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.nn as nn
import torch.nn.functional as F
import torch

def target_transform(labels):
    """Binarise ``labels`` into a floating-point 0/1 mask.

    Every entry that is strictly greater than zero becomes 1.0; every other
    entry becomes 0.0. The comparison is elementwise, so the result has the
    same shape as ``labels``.

    Args:
        labels: Array-like supporting elementwise ``>`` (presumably a torch
            tensor of raw labels -- confirm against the caller).

    Returns:
        Floating-point mask of 0.0 / 1.0 values.
    """
    is_positive = labels > 0
    # Adding a float scalar promotes the boolean mask to a floating dtype.
    return is_positive + 0.

def twoargmax(a):
    """Return the positions of the two largest values in ``a``.

    Positions are returned in ascending order of value, so the index of the
    maximum comes last. Equal values are ordered by position (lower index
    first), which means that on a tie the later index is the one kept.
    If ``a`` has fewer than two elements, all of its indices are returned.

    Args:
        a: Sized, indexable sequence of mutually comparable values
            (e.g. a list of floats or a 1-D tensor).

    Returns:
        List of at most two integer indices.
    """
    # sorted() is stable, so ties keep ascending-index order -- the same
    # ordering that sorting (value, index) pairs produces.
    order = sorted(range(len(a)), key=lambda idx: a[idx])
    return order[-2:]

def loss_function(output, target):
    """Compute the binary cross-entropy between ``output`` and ``target``.

    The per-element losses are averaged over all entries.

    Args:
        output: Tensor of predicted probabilities, which
            ``F.binary_cross_entropy`` requires to lie in [0, 1]
            (presumably produced by a sigmoid -- confirm).
        target: Tensor with the same shape as ``output``.

    Returns:
        Scalar tensor holding the mean loss.
    """
    # 'mean' is the library default; spelled out to make the averaging explicit.
    return F.binary_cross_entropy(output, target, reduction='mean')

def count_correct(output, target):
    """Count rows whose two top-scoring positions are both marked 1 in ``target``.

    For each row ``i``, ``twoargmax`` picks the two highest-scoring indices
    of ``output[i]``. The row counts as correct only if ``target[i]`` equals
    1 at every picked index.

    Args:
        output: Sequence of per-row score vectors (presumably a 2-D tensor
            of shape (batch, outputs) -- confirm against the caller).
        target: Per-row targets indexed in step with ``output``.

    Returns:
        Number of correct rows, as an int.
    """
    hits = 0
    for i, scores in enumerate(output):
        # Check every picked index (no short-circuit) before deciding, so
        # indexing errors surface just as with an exhaustive check.
        misses = [target[i][idx] != 1 for idx in twoargmax(scores)]
        if not any(misses):
            hits += 1
    return hits

def finalizer(x):
    """Squash raw scores ``x`` into [0, 1] with the logistic sigmoid.

    The function is applied elementwise.

    Args:
        x: Tensor of raw scores.

    Returns:
        Tensor of the same shape as ``x``.
    """
    probabilities = torch.sigmoid(x)
    return probabilities

# Size of the prediction vector for this classification setup (presumably
# the model's output-layer width -- confirm against the code that reads it).
outputs: int = 6