m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/confusion.py
blob: dd6fe81e85c20cad2314f6f2a65707bc103aff29 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
import torch
from sklearn.metrics import multilabel_confusion_matrix
from analysis import ExperimentData
from data import get_loaders
from classification import target_transform, twoargmax

# Require the experiment results directory as the first command-line argument.
if len(sys.argv) < 2:
    # Usage errors go to stderr. sys.exit is used rather than the bare `exit`
    # helper, which the `site` module installs and which is absent under
    # `python -S` or in frozen executables.
    print('provide results directory as argument', file=sys.stderr)
    sys.exit(1)

def predict(ys):
    """Convert a batch of per-label scores into a two-hot prediction tensor.

    For each row of ``ys``, the two positions reported by ``twoargmax`` are
    set to 1 in a freshly allocated tensor of the same shape; every other
    entry stays 0.

    Args:
        ys: tensor of scores whose first dimension indexes samples
            (assumes shape (batch, n_labels) — TODO confirm with the model).

    Returns:
        A new ``torch.zeros``-allocated tensor (default dtype and device) of
        shape ``ys.shape`` with exactly the two selected entries per row set
        to 1, assuming ``twoargmax`` returns two distinct indices.
    """
    n_rows = ys.shape[0]
    two_hot = torch.zeros(ys.shape)
    for row in range(n_rows):
        top = twoargmax(ys[row])
        # Mark the highest- and second-highest-scoring positions.
        for rank in (0, 1):
            two_hot[row, top[rank]] = 1
    return two_hot

def make_matrix(net, data, targets):
    """Evaluate ``net`` on one batch and return its multilabel confusion matrix.

    The network is switched to evaluation mode (this persists after the call),
    its forward pass is run without autograd tracking, and the raw outputs are
    turned into two-hot predictions by ``predict``.

    Args:
        net: a ``torch.nn.Module`` whose output for ``data`` has the same
            shape as ``targets`` (assumed (batch, n_labels) — TODO confirm).
        data: input batch fed directly to ``net``.
        targets: ground-truth label tensor; presumably multi-hot 0/1 —
            verify against ``classification.target_transform``.

    Returns:
        The ndarray from ``sklearn.metrics.multilabel_confusion_matrix``:
        one 2x2 confusion matrix per label.
    """
    net.eval()
    # torch.no_grad() only disables autograd when used as a context manager;
    # calling it as a bare statement has no effect.
    with torch.no_grad():
        scores = net(data)
    predictions = predict(scores)
    return multilabel_confusion_matrix(
        targets.detach().int(), predictions.detach().int())

# Load the trained network from the results directory given on the command line.
experiment = ExperimentData(sys.argv[1])
net = experiment.get_net()

# The loop below walks the whole loader and keeps only the batch it saw last;
# that single batch is what gets evaluated.
# NOTE(review): confirm that evaluating one batch (rather than the whole
# dataset) is intended. The arguments to get_loaders (including 100,
# presumably the batch size) are passed through unchanged — see data.get_loaders.
data = None
targets = None
_, train_loader = get_loaders([(0, False, False)], target_transform, 100)
for data, targets in train_loader:
    pass

if data is None:
    # An empty loader would otherwise fail opaquely inside net(None).
    sys.exit('data loader produced no batches')

matrix = make_matrix(net, data, targets)
print(matrix)