m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/analysis.py
blob: f7d82d9b1b617ebfdb9a5fe1c78731c3b8abd9fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pandas as pd
import torch
import yaml
import parse_config
from net import Net

class ExperimentData:
    """Artifacts of a single experiment run stored in one directory.

    The directory is expected to contain ``config.yaml`` and ``metrics.csv``
    (both read on construction) and ``net.pt`` (read by :meth:`get_net`).

    Attributes:
        directory: Path of the experiment directory, exactly as given.
        config: Contents of ``config.yaml`` as returned by ``yaml.safe_load``.
        metrics: ``metrics.csv`` loaded with ``pandas.read_csv``.
    """

    def __init__(self, directory):
        self.directory = directory
        self.config = self.get_config()
        self.metrics = self.get_metrics()

    def _path(self, filename):
        """Return the path of ``filename`` inside the experiment directory."""
        # No './' prefix: absolute experiment directories must stay absolute.
        return '{}/{}'.format(self.directory, filename)

    def get_config(self):
        """Read and return the parsed ``config.yaml`` of this experiment.

        ``yaml.safe_load`` is used, so only plain YAML types are constructed.
        """
        with open(self._path('config.yaml'), encoding='utf-8') as file:
            return yaml.safe_load(file)

    def get_metrics(self):
        """Return this experiment's ``metrics.csv`` as a pandas DataFrame."""
        return pd.read_csv(self._path('metrics.csv'))

    def get_net(self):
        """Build the network described by the config and load its weights.

        Side effects: stores ``epochs``, ``loss_function`` and
        ``count_correct`` (as returned by ``parse_config.parse``) on ``self``.

        Returns:
            A ``Net`` built from the config's network settings, with the state
            dict from ``net.pt`` loaded. Tensors are mapped to the CPU device.
        """
        (
            net_config,
            _lr,
            self.epochs,
            _batch_size,
            _augmentations,
            _target_transform,
            self.loss_function,
            self.count_correct,
        ) = parse_config.parse(self.config)

        net = Net(**net_config)
        # Resolve net.pt the same way as the other artifacts so that absolute
        # experiment directories are not turned into relative paths.
        state_dict = torch.load(self._path('net.pt'),
                                map_location=torch.device('cpu'))
        net.load_state_dict(state_dict)
        return net