diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:20:36 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:20:36 +0200 |
commit | 290d55c4353a7374da14d67bc9ab3d33c236fa95 (patch) | |
tree | f7531f12e1f78e32b59cabf2cf3570b6c5869a5f /src/experiment.py | |
parent | e6ea98728380a222459049987ddbb858464741d3 (diff) |
Implement configurable experiment runner
Diffstat (limited to 'src/experiment.py')
-rw-r--r-- | src/experiment.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/src/experiment.py b/src/experiment.py new file mode 100644 index 0000000..dd2033c --- /dev/null +++ b/src/experiment.py @@ -0,0 +1,68 @@ +import subprocess +import os +import time + +import pandas as pd +import matplotlib.pyplot as plt + +from runner import Runner + +class Experiment: + def __init__(self, file): + self.make_dir() + self.copy_config(file) + self.metrics = ExperimentMetrics() + self.runner = Runner(file, self.metrics) + + def run(self): + self.runner.run() + + def save_results(self): + data = self.metrics.get_dataframe() + data.to_csv(self.dir_path('metrics.csv')) + + plt.plot(data['train_losses'], label='train loss') + plt.plot(data['test_losses'], label='test loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend() + plt.savefig(self.dir_path('losses.png')) + plt.clf() + + plt.plot(data['test_accuracies'], label='test accuracy') + plt.xlabel('Epoch') + plt.ylabel('% correct') + plt.legend() + plt.savefig(self.dir_path('accuracies.png')) + + def dir_path(self, file): + return '{}/{}'.format(self.dirname, file) + + def make_dir(self): + time_string = time.strftime('%Y%m%d%H%M%S') + dirname = 'outputs/{}'.format(time_string) + self.dirname = dirname + os.mkdir(dirname) + + def copy_config(self, file): + subprocess.run(['cp', file, '{}/config.yaml'.format(self.dirname)]) + +class ExperimentMetrics: + def __init__(self): + self.train_losses = [] + self.test_losses = [] + self.test_accuracies = [] + + def add_train_loss(self, loss): + self.train_losses.append(loss) + + def add_test_metrics(self, loss, accuracy): + self.test_losses.append(loss) + self.test_accuracies.append(accuracy) + + def get_dataframe(self): + return pd.DataFrame({ + 'train_losses': self.train_losses, + 'test_losses': self.test_losses, + 'test_accuracies': self.test_accuracies, + }) |