import os

import pandas as pd
import torch
import yaml

import parse_config
from net import Net


class ExperimentData:
    """Access the stored artifacts of one experiment directory.

    The directory is read for ``config.yaml`` and ``metrics.csv`` on
    construction, and for ``net.pt`` when :meth:`get_net` is called.
    """

    def __init__(self, directory):
        """Load the config and metrics found in *directory*.

        Args:
            directory: Path (relative or absolute) of the experiment directory.
        """
        self.directory = directory
        self.config = self.get_config()
        self.metrics = self.get_metrics()

    def _path(self, filename):
        """Return the path of *filename* inside the experiment directory."""
        # os.path.join keeps absolute directories absolute; the previous
        # './{}/...' form turned '/abs/dir' into the relative './/abs/dir'.
        return os.path.join(self.directory, filename)

    def get_config(self):
        """Return the contents of ``config.yaml`` parsed with ``yaml.safe_load``."""
        with open(self._path('config.yaml'), encoding='utf-8') as file:
            return yaml.safe_load(file)

    def get_metrics(self):
        """Return ``metrics.csv`` as a pandas DataFrame."""
        return pd.read_csv(self._path('metrics.csv'))

    def get_net(self):
        """Build a ``Net`` from the config and load its saved weights.

        Side effects: stores ``self.epochs``, ``self.loss_function`` and
        ``self.count_correct`` as returned by ``parse_config.parse``.

        Returns:
            A ``Net`` instance with the state dict from ``net.pt`` loaded
            onto the CPU.
        """
        # parse_config.parse returns an 8-tuple; only some fields are used
        # here, the rest are discarded.
        (
            net_config,
            _lr,
            self.epochs,
            _batch_size,
            _augmentations,
            _target_transform,
            self.loss_function,
            self.count_correct,
        ) = parse_config.parse(self.config)

        net = Net(**net_config)
        # NOTE(review): torch.load unpickles the file; only load net.pt files
        # from trusted sources.
        state_dict = torch.load(
            self._path('net.pt'), map_location=torch.device('cpu')
        )
        net.load_state_dict(state_dict)
        return net