diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/analysis.py | 36 | ||||
| -rw-r--r-- | src/confusion.py | 39 | 
2 files changed, 75 insertions, 0 deletions
| diff --git a/src/analysis.py b/src/analysis.py new file mode 100644 index 0000000..f7d82d9 --- /dev/null +++ b/src/analysis.py @@ -0,0 +1,36 @@ +import pandas as pd +import torch +import yaml +import parse_config +from net import Net + +class ExperimentData: +    def __init__(self, directory): +        self.directory = directory +        self.config = self.get_config() +        self.metrics = self.get_metrics() + +    def get_config(self): +        with open('{}/config.yaml'.format(self.directory)) as file: +            return yaml.safe_load(file) + +    def get_metrics(self): +        return pd.read_csv('{}/metrics.csv'.format(self.directory)) + +    def get_net(self): +        ( +            net_config, +            lr, +            self.epochs, +            batch_size, +            augmentations, +            target_transform, +            self.loss_function, +            self.count_correct +        ) = parse_config.parse(self.config) + +        net = Net(**net_config) +        model_file = './{}/net.pt'.format(self.directory) +        net.load_state_dict(torch.load(model_file, +            map_location=torch.device('cpu'))) +        return net diff --git a/src/confusion.py b/src/confusion.py new file mode 100644 index 0000000..dd6fe81 --- /dev/null +++ b/src/confusion.py @@ -0,0 +1,39 @@ +import sys +import torch +from sklearn.metrics import multilabel_confusion_matrix +from analysis import ExperimentData +from data import get_loaders +from classification import target_transform, twoargmax + +if len(sys.argv) < 2: +    print('provide results directory as argument') +    exit(1) + +def predict(ys): +    predictions = torch.zeros(ys.shape) +    for i in range(ys.shape[0]): +        maxes = twoargmax(ys[i]) +        predictions[i][maxes[0]] = 1 +        predictions[i][maxes[1]] = 1 +    return predictions + +def make_matrix(net, data, targets): +    matrix = None +    net.eval() +    torch.no_grad() +    predictions = net(data) +    predictions = predict(predictions) +    matrix = multilabel_confusion_matrix(targets.detach().int(), predictions.detach().int()) +    return matrix + +experiment = ExperimentData(sys.argv[1]) +net = experiment.get_net() + +data = None +targets = None +_, train_loader = get_loaders([(0, False, False)], target_transform, 100) +for idx, (data, targets) in enumerate(train_loader): +    pass + +matrix = make_matrix(net, data, targets) +print(matrix) |