import os
import shutil
import subprocess
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch

from runner import Runner


class Experiment:
    """A single run with its own timestamped output directory.

    Construction creates ``outputs/[<prefix>-]<YYYYmmddHHMMSS>``, copies the
    config file into it as ``config.yaml``, and builds a ``Runner`` that is
    handed an ``ExperimentMetrics`` instance to record into.
    """

    def __init__(self, file, prefix):
        """Set up the output directory, config copy, metrics and runner.

        Args:
            file: Path to the config file. It is copied into the output
                directory and passed unchanged to ``Runner``.
            prefix: Label prepended (followed by ``-``) to the directory
                name. An empty string or ``None`` means no prefix.

        Raises:
            FileExistsError: If the output directory already exists.
            OSError: If the config file cannot be copied.
        """
        self.make_dir(prefix)
        self.copy_config(file)
        self.metrics = ExperimentMetrics()
        # NOTE(review): Runner presumably reads the config and records into
        # self.metrics via add_train_loss/add_test_metrics — confirm in runner.py.
        self.runner = Runner(file, self.metrics)

    def run(self):
        """Delegate to ``Runner.run``."""
        self.runner.run()

    def save_results(self):
        """Write metrics.csv, losses.png, accuracies.png and net.pt.

        All files go into this experiment's output directory.
        """
        data = self.metrics.get_dataframe()
        data.to_csv(self.dir_path('metrics.csv'))
        self._save_plot(
            data,
            [('train_losses', 'train loss'), ('test_losses', 'test loss')],
            'Loss',
            'losses.png',
        )
        self._save_plot(
            data,
            [('test_accuracies', 'test accuracy')],
            '% correct',
            'accuracies.png',
        )
        # NOTE(review): assumes Runner exposes the model as ``.net`` with a
        # ``state_dict()`` (i.e. a torch nn.Module) — confirm in runner.py.
        torch.save(self.runner.net.state_dict(), self.dir_path('net.pt'))

    def _save_plot(self, data, series, ylabel, filename):
        """Plot the given columns of *data* against row index and save the figure.

        Args:
            data: DataFrame returned by ``ExperimentMetrics.get_dataframe``.
            series: Sequence of ``(column_name, legend_label)`` pairs to plot.
            ylabel: Y-axis label.
            filename: File name inside the output directory.
        """
        # A dedicated figure keeps pyplot's global "current figure" untouched.
        fig, ax = plt.subplots()
        try:
            for column, label in series:
                ax.plot(data[column], label=label)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            fig.savefig(self.dir_path(filename))
        finally:
            # Always release the figure so repeated calls don't accumulate
            # lines or leak figure memory.
            plt.close(fig)

    def dir_path(self, file):
        """Return the path of *file* inside this experiment's output directory."""
        return os.path.join(self.dirname, file)

    def make_dir(self, prefix):
        """Create the timestamped output directory and store it in ``self.dirname``.

        The directory is ``outputs/[<prefix>-]<YYYYmmddHHMMSS>`` (local time).
        Missing parent directories such as ``outputs`` are created.

        Raises:
            FileExistsError: If the directory already exists. This can happen,
                for example, when two experiments with the same prefix start
                within the same second.
        """
        time_string = time.strftime('%Y%m%d%H%M%S')
        prefix = '{}-'.format(prefix) if prefix else ''
        dirname = os.path.join('outputs', '{}{}'.format(prefix, time_string))
        self.dirname = dirname
        os.makedirs(dirname)

    def copy_config(self, file):
        """Copy the config *file* into the output directory as ``config.yaml``.

        Raises:
            OSError: If the copy fails (e.g. ``FileNotFoundError`` for a
                missing source file).
        """
        shutil.copy(file, self.dir_path('config.yaml'))


def _round_scalar(value, ndigits=3):
    """Convert a scalar to a ``float`` rounded to *ndigits* decimal places.

    Accepts anything ``float()`` accepts, e.g. a 0-d or single-element torch
    tensor or a plain Python number.
    """
    return round(float(value), ndigits)


class ExperimentMetrics:
    """In-memory history of losses and accuracies recorded during a run.

    Each list gets one entry per ``add_*`` call. ``Experiment.save_results``
    plots the row index as "Epoch".
    """

    def __init__(self):
        self.train_losses = []  # rounded to 3 decimals
        self.test_losses = []  # rounded to 3 decimals
        self.test_accuracies = []  # stored as given, no conversion

    def add_train_loss(self, loss):
        """Record one training loss, converted to a float rounded to 3 decimals."""
        self.train_losses.append(_round_scalar(loss))

    def add_test_metrics(self, loss, accuracy):
        """Record one test loss (rounded to 3 decimals) and its accuracy (as given)."""
        self.test_losses.append(_round_scalar(loss))
        self.test_accuracies.append(accuracy)

    def get_dataframe(self):
        """Return the recorded metrics as a DataFrame, one column per metric.

        If the lists differ in length (e.g. a run stopped after a train step
        but before the matching test step), shorter columns are padded with
        NaN at the end instead of raising ``ValueError``.
        """
        columns = {
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'test_accuracies': self.test_accuracies,
        }
        if len({len(values) for values in columns.values()}) > 1:
            # Series of unequal length are aligned on index and NaN-padded.
            return pd.DataFrame(
                {name: pd.Series(values) for name, values in columns.items()}
            )
        return pd.DataFrame(columns)