diff options
Diffstat (limited to 'src/experiment.py')
-rw-r--r-- | src/experiment.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/src/experiment.py b/src/experiment.py new file mode 100644 index 0000000..dd2033c --- /dev/null +++ b/src/experiment.py @@ -0,0 +1,68 @@ +import subprocess +import os +import time + +import pandas as pd +import matplotlib.pyplot as plt + +from runner import Runner + +class Experiment: + def __init__(self, file): + self.make_dir() + self.copy_config(file) + self.metrics = ExperimentMetrics() + self.runner = Runner(file, self.metrics) + + def run(self): + self.runner.run() + + def save_results(self): + data = self.metrics.get_dataframe() + data.to_csv(self.dir_path('metrics.csv')) + + plt.plot(data['train_losses'], label='train loss') + plt.plot(data['test_losses'], label='test loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend() + plt.savefig(self.dir_path('losses.png')) + plt.clf() + + plt.plot(data['test_accuracies'], label='test accuracy') + plt.xlabel('Epoch') + plt.ylabel('% correct') + plt.legend() + plt.savefig(self.dir_path('accuracies.png')) + + def dir_path(self, file): + return '{}/{}'.format(self.dirname, file) + + def make_dir(self): + time_string = time.strftime('%Y%m%d%H%M%S') + dirname = 'outputs/{}'.format(time_string) + self.dirname = dirname + os.mkdir(dirname) + + def copy_config(self, file): + subprocess.run(['cp', file, '{}/config.yaml'.format(self.dirname)]) + +class ExperimentMetrics: + def __init__(self): + self.train_losses = [] + self.test_losses = [] + self.test_accuracies = [] + + def add_train_loss(self, loss): + self.train_losses.append(loss) + + def add_test_metrics(self, loss, accuracy): + self.test_losses.append(loss) + self.test_accuracies.append(accuracy) + + def get_dataframe(self): + return pd.DataFrame({ + 'train_losses': self.train_losses, + 'test_losses': self.test_losses, + 'test_accuracies': self.test_accuracies, + }) |