diff options
Diffstat (limited to 'src/analysis.py')
-rw-r--r-- | src/analysis.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/src/analysis.py b/src/analysis.py new file mode 100644 index 0000000..f7d82d9 --- /dev/null +++ b/src/analysis.py @@ -0,0 +1,36 @@ +import pandas as pd +import torch +import yaml +import parse_config +from net import Net + +class ExperimentData: + def __init__(self, directory): + self.directory = directory + self.config = self.get_config() + self.metrics = self.get_metrics() + + def get_config(self): + with open('{}/config.yaml'.format(self.directory)) as file: + return yaml.safe_load(file) + + def get_metrics(self): + return pd.read_csv('{}/metrics.csv'.format(self.directory)) + + def get_net(self): + ( + net_config, + lr, + self.epochs, + batch_size, + augmentations, + target_transform, + self.loss_function, + self.count_correct + ) = parse_config.parse(self.config) + + net = Net(**net_config) + model_file = './{}/net.pt'.format(self.directory) + net.load_state_dict(torch.load(model_file, + map_location=torch.device('cpu'))) + return net |