blob: f7d82d9b1b617ebfdb9a5fe1c78731c3b8abd9fb (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
import os

import pandas as pd
import torch
import yaml

import parse_config
from net import Net
class ExperimentData:
    """Load the artifacts stored in a single experiment directory.

    ``config.yaml`` and ``metrics.csv`` are read eagerly in the constructor;
    the network weights (``net.pt``) are only loaded when :meth:`get_net`
    is called.
    """

    def __init__(self, directory):
        """Read the config and metrics found in *directory*.

        Args:
            directory: Path (str or path-like) of the experiment directory.
                Relative and absolute paths are both supported.

        Raises:
            OSError: (e.g. FileNotFoundError) if ``config.yaml`` or
                ``metrics.csv`` cannot be opened.
        """
        self.directory = directory
        self.config = self.get_config()
        self.metrics = self.get_metrics()

    def _path(self, filename):
        """Return the path of *filename* inside the experiment directory."""
        # os.path.join keeps absolute directories absolute; the former
        # './{}/...' prefix turned '/abs/dir' into './/abs/dir', i.e. a path
        # relative to the current working directory.
        return os.path.join(self.directory, filename)

    def get_config(self):
        """Parse ``config.yaml`` with ``yaml.safe_load`` and return the result."""
        with open(self._path('config.yaml'), encoding='utf-8') as file:
            return yaml.safe_load(file)

    def get_metrics(self):
        """Read ``metrics.csv`` into a pandas DataFrame and return it."""
        return pd.read_csv(self._path('metrics.csv'))

    def get_net(self):
        """Rebuild the network from the config and load its saved weights.

        ``parse_config.parse(self.config)`` must return an 8-tuple. Only
        four of its elements are used here. As a side effect, this method
        sets ``self.epochs``, ``self.loss_function`` and
        ``self.count_correct`` from that tuple.

        Returns:
            A ``Net`` constructed from the parsed network config, with the
            state dict from ``net.pt`` loaded onto the CPU.
        """
        # NOTE(review): element meanings are inferred from the original
        # variable names; confirm against parse_config.parse.
        (
            net_config,
            _lr,
            self.epochs,
            _batch_size,
            _augmentations,
            _target_transform,
            self.loss_function,
            self.count_correct,
        ) = parse_config.parse(self.config)
        net = Net(**net_config)
        # torch.load unpickles the file: only load trusted checkpoints.
        # Consider weights_only=True if the installed torch supports it.
        state_dict = torch.load(self._path('net.pt'),
                                map_location=torch.device('cpu'))
        net.load_state_dict(state_dict)
        return net
|