m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/analysis.py')
-rw-r--r--src/analysis.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/src/analysis.py b/src/analysis.py
new file mode 100644
index 0000000..f7d82d9
--- /dev/null
+++ b/src/analysis.py
@@ -0,0 +1,36 @@
+import pandas as pd
+import torch
+import yaml
+import parse_config
+from net import Net
+
+class ExperimentData:
+ def __init__(self, directory):
+ self.directory = directory
+ self.config = self.get_config()
+ self.metrics = self.get_metrics()
+
+ def get_config(self):
+ with open('{}/config.yaml'.format(self.directory)) as file:
+ return yaml.safe_load(file)
+
+ def get_metrics(self):
+ return pd.read_csv('{}/metrics.csv'.format(self.directory))
+
+ def get_net(self):
+ (
+ net_config,
+ lr,
+ self.epochs,
+ batch_size,
+ augmentations,
+ target_transform,
+ self.loss_function,
+ self.count_correct
+ ) = parse_config.parse(self.config)
+
+ net = Net(**net_config)
+ model_file = './{}/net.pt'.format(self.directory)
+ net.load_state_dict(torch.load(model_file,
+ map_location=torch.device('cpu')))
+ return net