diff options
-rw-r--r-- | src/analysis.py | 36 | ||||
-rw-r--r-- | src/confusion.py | 39 |
2 files changed, 75 insertions, 0 deletions
diff --git a/src/analysis.py b/src/analysis.py new file mode 100644 index 0000000..f7d82d9 --- /dev/null +++ b/src/analysis.py @@ -0,0 +1,36 @@ +import pandas as pd +import torch +import yaml +import parse_config +from net import Net + +class ExperimentData: + def __init__(self, directory): + self.directory = directory + self.config = self.get_config() + self.metrics = self.get_metrics() + + def get_config(self): + with open('{}/config.yaml'.format(self.directory)) as file: + return yaml.safe_load(file) + + def get_metrics(self): + return pd.read_csv('{}/metrics.csv'.format(self.directory)) + + def get_net(self): + ( + net_config, + lr, + self.epochs, + batch_size, + augmentations, + target_transform, + self.loss_function, + self.count_correct + ) = parse_config.parse(self.config) + + net = Net(**net_config) + model_file = './{}/net.pt'.format(self.directory) + net.load_state_dict(torch.load(model_file, + map_location=torch.device('cpu'))) + return net diff --git a/src/confusion.py b/src/confusion.py new file mode 100644 index 0000000..dd6fe81 --- /dev/null +++ b/src/confusion.py @@ -0,0 +1,39 @@ +import sys +import torch +from sklearn.metrics import multilabel_confusion_matrix +from analysis import ExperimentData +from data import get_loaders +from classification import target_transform, twoargmax + +if len(sys.argv) < 2: + print('provide results directory as argument') + exit(1) + +def predict(ys): + predictions = torch.zeros(ys.shape) + for i in range(ys.shape[0]): + maxes = twoargmax(ys[i]) + predictions[i][maxes[0]] = 1 + predictions[i][maxes[1]] = 1 + return predictions + +def make_matrix(net, data, targets): + matrix = None + net.eval() + torch.no_grad() + predictions = net(data) + predictions = predict(predictions) + matrix = multilabel_confusion_matrix(targets.detach().int(), predictions.detach().int()) + return matrix + +experiment = ExperimentData(sys.argv[1]) +net = experiment.get_net() + +data = None +targets = None +_, train_loader = get_loaders([(0, False, False)], target_transform, 100) +for idx, (data, targets) in enumerate(train_loader): + pass + +matrix = make_matrix(net, data, targets) +print(matrix) |