diff options
Diffstat (limited to 'src/confusion.py')
-rw-r--r-- | src/confusion.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/confusion.py b/src/confusion.py new file mode 100644 index 0000000..dd6fe81 --- /dev/null +++ b/src/confusion.py @@ -0,0 +1,39 @@ +import sys +import torch +from sklearn.metrics import multilabel_confusion_matrix +from analysis import ExperimentData +from data import get_loaders +from classification import target_transform, twoargmax + +if len(sys.argv) < 2: + print('provide results directory as argument') + exit(1) + +def predict(ys): + predictions = torch.zeros(ys.shape) + for i in range(ys.shape[0]): + maxes = twoargmax(ys[i]) + predictions[i][maxes[0]] = 1 + predictions[i][maxes[1]] = 1 + return predictions + +def make_matrix(net, data, targets): + matrix = None + net.eval() + torch.no_grad() + predictions = net(data) + predictions = predict(predictions) + matrix = multilabel_confusion_matrix(targets.detach().int(), predictions.detach().int()) + return matrix + +experiment = ExperimentData(sys.argv[1]) +net = experiment.get_net() + +data = None +targets = None +_, train_loader = get_loaders([(0, False, False)], target_transform, 100) +for idx, (data, targets) in enumerate(train_loader): + pass + +matrix = make_matrix(net, data, targets) +print(matrix) |