m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/analysis.py36
-rw-r--r--src/confusion.py39
2 files changed, 75 insertions, 0 deletions
diff --git a/src/analysis.py b/src/analysis.py
new file mode 100644
index 0000000..f7d82d9
--- /dev/null
+++ b/src/analysis.py
@@ -0,0 +1,36 @@
+import pandas as pd
+import torch
+import yaml
+import parse_config
+from net import Net
+
+class ExperimentData:
+ def __init__(self, directory):
+ self.directory = directory
+ self.config = self.get_config()
+ self.metrics = self.get_metrics()
+
+ def get_config(self):
+ with open('{}/config.yaml'.format(self.directory)) as file:
+ return yaml.safe_load(file)
+
+ def get_metrics(self):
+ return pd.read_csv('{}/metrics.csv'.format(self.directory))
+
+ def get_net(self):
+ (
+ net_config,
+ lr,
+ self.epochs,
+ batch_size,
+ augmentations,
+ target_transform,
+ self.loss_function,
+ self.count_correct
+ ) = parse_config.parse(self.config)
+
+ net = Net(**net_config)
+ model_file = './{}/net.pt'.format(self.directory)
+ net.load_state_dict(torch.load(model_file,
+ map_location=torch.device('cpu')))
+ return net
diff --git a/src/confusion.py b/src/confusion.py
new file mode 100644
index 0000000..dd6fe81
--- /dev/null
+++ b/src/confusion.py
@@ -0,0 +1,39 @@
+import sys
+import torch
+from sklearn.metrics import multilabel_confusion_matrix
+from analysis import ExperimentData
+from data import get_loaders
+from classification import target_transform, twoargmax
+
+if len(sys.argv) < 2:
+ print('provide results directory as argument')
+ exit(1)
+
+def predict(ys):
+ predictions = torch.zeros(ys.shape)
+ for i in range(ys.shape[0]):
+ maxes = twoargmax(ys[i])
+ predictions[i][maxes[0]] = 1
+ predictions[i][maxes[1]] = 1
+ return predictions
+
+def make_matrix(net, data, targets):
+ matrix = None
+ net.eval()
+ torch.no_grad()
+ predictions = net(data)
+ predictions = predict(predictions)
+ matrix = multilabel_confusion_matrix(targets.detach().int(), predictions.detach().int())
+ return matrix
+
+experiment = ExperimentData(sys.argv[1])
+net = experiment.get_net()
+
+data = None
+targets = None
+_, train_loader = get_loaders([(0, False, False)], target_transform, 100)
+for idx, (data, targets) in enumerate(train_loader):
+ pass
+
+matrix = make_matrix(net, data, targets)
+print(matrix)