"""Print the multilabel confusion matrix of an experiment's trained network.

Usage: pass the results directory as the first command-line argument.
"""
import sys

import torch
from sklearn.metrics import multilabel_confusion_matrix

from analysis import ExperimentData
from data import get_loaders
from classification import target_transform, twoargmax


def predict(ys):
    """Convert raw network outputs into two-hot prediction rows.

    For each row of ``ys``, the two indices returned by ``twoargmax`` are set
    to 1 and every other entry stays 0.

    Args:
        ys: tensor of network outputs; presumably (n_samples, n_labels) --
            TODO confirm against the network definition.

    Returns:
        A float tensor of 0/1 values with the same shape as ``ys``
        (allocated with ``torch.zeros``, i.e. on the default device).
    """
    predictions = torch.zeros(ys.shape)
    for i, row in enumerate(ys):
        maxes = twoargmax(row)
        predictions[i][maxes[0]] = 1
        predictions[i][maxes[1]] = 1
    return predictions


def make_matrix(net, data, targets):
    """Run ``net`` on one batch and return its multilabel confusion matrix.

    Args:
        net: the model to evaluate; it is switched to eval mode.
        data: a batch of inputs accepted by ``net``.
        targets: indicator tensor of true labels for ``data``.

    Returns:
        The array produced by ``sklearn.metrics.multilabel_confusion_matrix``
        for this batch (one 2x2 count matrix per label).
    """
    net.eval()
    # torch.no_grad() only takes effect as a context manager or decorator;
    # calling it as a bare statement left autograd fully enabled.
    with torch.no_grad():
        outputs = net(data)
        predictions = predict(outputs)
    return multilabel_confusion_matrix(
        targets.detach().int(), predictions.detach().int()
    )


def main(argv=None):
    """Evaluate the experiment named by ``argv[1]`` and print its matrix.

    The per-batch confusion matrices are summed. Each entry is a count, so
    the sum equals the matrix over every sample the loader yields, instead of
    reflecting only the final batch.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print('provide results directory as argument', file=sys.stderr)
        sys.exit(1)

    experiment = ExperimentData(argv[1])
    net = experiment.get_net()

    _, train_loader = get_loaders([(0, False, False)], target_transform, 100)

    total = None
    for data, targets in train_loader:
        batch_matrix = make_matrix(net, data, targets)
        total = batch_matrix if total is None else total + batch_matrix

    if total is None:
        # Previously this case crashed inside net(None).
        print('loader yielded no batches', file=sys.stderr)
        sys.exit(1)

    print(total)


if __name__ == '__main__':
    main()