import os
import shutil
import subprocess
import time

import pandas as pd
import torch

from train.train import train_model
from util.parse_config import parse_file
from model.encoder import EncoderModel


class Experiment:
    """One training run whose artifacts are written to a timestamped directory.

    Constructing an ``Experiment`` immediately creates the output directory
    ``outputs/[<prefix>-]<YYYYmmddHHMMSS>`` and copies the config file into it
    as ``config.yaml``. ``run()`` then builds and trains the model and saves
    its weights (``net.pt``) and training metrics (``metrics.csv``) there.
    """

    def __init__(self, file, prefix, device):
        """Create the output directory and snapshot the config file into it.

        Args:
            file: Path to the config file. It is parsed by ``parse_file`` in
                ``run()`` and copied verbatim into the output directory.
            prefix: Label prepended (with a ``-``) to the directory name. An
                empty string or ``None`` means no label.
            device: Passed as ``device=`` to ``EncoderModel`` and
                ``train_model`` and used for ``model.to(device)``
                (presumably a torch device or device string — confirm).

        Raises:
            FileExistsError: If the output directory already exists.
            OSError: If the config file cannot be copied.
        """
        self.file = file
        self.device = device
        self.make_dir(prefix)
        self.copy_config(file)

    def run(self):
        """Build the model from the config, train it, and save the results.

        ``parse_file`` yields a ``(model_config, train_config)`` pair whose
        entries are forwarded as keyword arguments to ``EncoderModel`` and
        ``train_model`` respectively.
        """
        model_config, train_config = parse_file(self.file)
        model = EncoderModel(device=self.device, **model_config).to(self.device)
        train_losses, test_losses, accuracies = train_model(
            model, device=self.device, **train_config
        )
        self.save_model(model)
        self.save_metrics(train_losses, test_losses, accuracies)

    def save_model(self, model):
        """Save ``model.state_dict()`` to ``<dirname>/net.pt``."""
        torch.save(model.state_dict(), self.dir_path('net.pt'))

    def dir_path(self, file):
        """Return the path of ``file`` inside this experiment's directory."""
        return os.path.join(self.dirname, file)

    def save_metrics(self, train_losses, test_losses, accuracies):
        """Write the training curves to ``<dirname>/metrics.csv``.

        Columns are ``train_loss``, ``test_loss`` and ``accuracy``, preceded
        by pandas' default integer index. pandas requires the three sequences
        to have the same length (it raises ``ValueError`` otherwise).
        """
        data_frame = pd.DataFrame({
            'train_loss': train_losses,
            'test_loss': test_losses,
            'accuracy': accuracies,
        })
        data_frame.to_csv(self.dir_path('metrics.csv'))

    def make_dir(self, prefix):
        """Create ``outputs/[<prefix>-]<timestamp>`` and record it as ``self.dirname``.

        The timestamp is local time formatted as ``%Y%m%d%H%M%S``. Missing
        parent directories (including ``outputs/`` itself) are created.

        Raises:
            FileExistsError: If the directory already exists, e.g. when two
                experiments with the same prefix start within the same second.
        """
        time_string = time.strftime('%Y%m%d%H%M%S')
        # A falsy prefix ('' or None) yields a bare timestamp; otherwise "<prefix>-".
        label = '{}-'.format(prefix) if prefix else ''
        dirname = os.path.join('outputs', '{}{}'.format(label, time_string))
        self.dirname = dirname
        # makedirs also creates 'outputs/' on a fresh checkout, but (without
        # exist_ok) still refuses to reuse an existing run directory.
        os.makedirs(dirname)

    def copy_config(self, file):
        """Copy the config file to ``<dirname>/config.yaml``.

        Raises:
            OSError: (e.g. ``FileNotFoundError``) if the copy fails.
        """
        # shutil works without a POSIX `cp` binary and raises on failure
        # instead of silently ignoring a non-zero exit status.
        shutil.copyfile(file, self.dir_path('config.yaml'))