m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/confusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/confusion.py')
-rw-r--r--src/confusion.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/confusion.py b/src/confusion.py
new file mode 100644
index 0000000..dd6fe81
--- /dev/null
+++ b/src/confusion.py
@@ -0,0 +1,39 @@
+import sys
+import torch
+from sklearn.metrics import multilabel_confusion_matrix
+from analysis import ExperimentData
+from data import get_loaders
+from classification import target_transform, twoargmax
+
+if len(sys.argv) < 2:
+ print('provide results directory as argument')
+ exit(1)
+
+def predict(ys):
+ predictions = torch.zeros(ys.shape)
+ for i in range(ys.shape[0]):
+ maxes = twoargmax(ys[i])
+ predictions[i][maxes[0]] = 1
+ predictions[i][maxes[1]] = 1
+ return predictions
+
+def make_matrix(net, data, targets):
+ matrix = None
+ net.eval()
+ torch.no_grad()
+ predictions = net(data)
+ predictions = predict(predictions)
+ matrix = multilabel_confusion_matrix(targets.detach().int(), predictions.detach().int())
+ return matrix
+
+experiment = ExperimentData(sys.argv[1])
+net = experiment.get_net()
+
+data = None
+targets = None
+_, train_loader = get_loaders([(0, False, False)], target_transform, 100)
+for idx, (data, targets) in enumerate(train_loader):
+ pass
+
+matrix = make_matrix(net, data, targets)
+print(matrix)