m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/experiment.py')
-rw-r--r--src/experiment.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/src/experiment.py b/src/experiment.py
new file mode 100644
index 0000000..dd2033c
--- /dev/null
+++ b/src/experiment.py
@@ -0,0 +1,68 @@
+import subprocess
+import os
+import time
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from runner import Runner
+
+class Experiment:
+ def __init__(self, file):
+ self.make_dir()
+ self.copy_config(file)
+ self.metrics = ExperimentMetrics()
+ self.runner = Runner(file, self.metrics)
+
+ def run(self):
+ self.runner.run()
+
+ def save_results(self):
+ data = self.metrics.get_dataframe()
+ data.to_csv(self.dir_path('metrics.csv'))
+
+ plt.plot(data['train_losses'], label='train loss')
+ plt.plot(data['test_losses'], label='test loss')
+ plt.xlabel('Epoch')
+ plt.ylabel('Loss')
+ plt.legend()
+ plt.savefig(self.dir_path('losses.png'))
+ plt.clf()
+
+ plt.plot(data['test_accuracies'], label='test accuracy')
+ plt.xlabel('Epoch')
+ plt.ylabel('% correct')
+ plt.legend()
+ plt.savefig(self.dir_path('accuracies.png'))
+
+ def dir_path(self, file):
+ return '{}/{}'.format(self.dirname, file)
+
+ def make_dir(self):
+ time_string = time.strftime('%Y%m%d%H%M%S')
+ dirname = 'outputs/{}'.format(time_string)
+ self.dirname = dirname
+ os.mkdir(dirname)
+
+ def copy_config(self, file):
+ subprocess.run(['cp', file, '{}/config.yaml'.format(self.dirname)])
+
+class ExperimentMetrics:
+ def __init__(self):
+ self.train_losses = []
+ self.test_losses = []
+ self.test_accuracies = []
+
+ def add_train_loss(self, loss):
+ self.train_losses.append(loss)
+
+ def add_test_metrics(self, loss, accuracy):
+ self.test_losses.append(loss)
+ self.test_accuracies.append(accuracy)
+
+ def get_dataframe(self):
+ return pd.DataFrame({
+ 'train_losses': self.train_losses,
+ 'test_losses': self.test_losses,
+ 'test_accuracies': self.test_accuracies,
+ })