blob: dd6fe81e85c20cad2314f6f2a65707bc103aff29 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
import sys
import torch
from sklearn.metrics import multilabel_confusion_matrix
from analysis import ExperimentData
from data import get_loaders
from classification import target_transform, twoargmax
if len(sys.argv) < 2:
    # Usage errors belong on stderr. sys.exit is the supported way to stop a
    # script: the bare exit() builtin is injected by the `site` module for
    # interactive use and is missing when Python runs with -S.
    print('provide results directory as argument', file=sys.stderr)
    sys.exit(1)
def predict(ys):
    """Binarize scores by marking two selected labels per row.

    For each row of ``ys``, ``twoargmax`` is asked for label indices; the
    first two it returns are set to 1 in an otherwise all-zero tensor.

    Args:
        ys: score tensor indexed as ``ys[row][label]`` (presumably
            ``(batch, num_labels)`` network outputs — TODO confirm).

    Returns:
        A new tensor of shape ``ys.shape`` built with ``torch.zeros``
        (default dtype and device), containing 1 at the two chosen
        positions of each row and 0 elsewhere.
    """
    one_hot = torch.zeros(ys.shape)
    for row_idx, row_scores in enumerate(ys):
        top_two = twoargmax(row_scores)
        for label_idx in (top_two[0], top_two[1]):
            one_hot[row_idx][label_idx] = 1
    return one_hot
def make_matrix(net, data, targets):
    """Compute per-label confusion matrices for ``net`` on one batch.

    The network is put in eval mode and run with autograd disabled. Its raw
    outputs are binarized by ``predict`` (two labels marked per sample) and
    compared against ``targets`` with sklearn's
    ``multilabel_confusion_matrix``.

    Args:
        net: module exposing ``eval()`` and callable on ``data``.
        data: input batch passed directly to ``net``.
        targets: target tensor compared element-wise with the binarized
            predictions (assumed multi-hot, same shape as the network
            output — TODO confirm against ``target_transform``).

    Returns:
        The array returned by ``multilabel_confusion_matrix``: one 2x2
        confusion matrix per label.

    Note:
        ``net`` is left in eval mode after the call.
    """
    net.eval()
    # torch.no_grad() disables gradient tracking only as a context manager
    # (or decorator); calling it bare, as before, had no effect.
    with torch.no_grad():
        outputs = net(data)
    predictions = predict(outputs)
    # sklearn needs host-side data; .cpu() is a no-op for CPU tensors.
    y_true = targets.detach().cpu().int()
    y_pred = predictions.detach().cpu().int()
    return multilabel_confusion_matrix(y_true, y_pred)
experiment = ExperimentData(sys.argv[1])
net = experiment.get_net()
# NOTE(review): the second value returned by get_loaders is bound as
# train_loader; confirm which split it actually is.
_, train_loader = get_loaders([(0, False, False)], target_transform, 100)
# Only the last batch the loader yields is evaluated; the loop walks the
# whole loader to reach it.
data = None
targets = None
for data, targets in train_loader:
    pass
if data is None:
    # An empty loader would otherwise surface as a cryptic failure in net(None).
    sys.exit('data loader yielded no batches')
matrix = make_matrix(net, data, targets)
print(matrix)
|