From a4e33358691431575d169a6102a945e16d132a44 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sat, 29 May 2021 14:14:33 +0200 Subject: Save model and metrics --- experiment/experiment.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'experiment/experiment.py') diff --git a/experiment/experiment.py b/experiment/experiment.py index 600732f..b06408f 100644 --- a/experiment/experiment.py +++ b/experiment/experiment.py @@ -2,6 +2,9 @@ import subprocess import os import time +import pandas as pd +import torch + from train.train import train_model from util.parse_config import parse_file from model.encoder import EncoderModel @@ -16,11 +19,24 @@ class Experiment: def run(self): model_config, train_config = parse_file(self.file) model = EncoderModel(device=self.device, **model_config).to(self.device) - train_model(model, device=self.device, **train_config) + train_losses, test_losses, accuracies = train_model(model, device=self.device, **train_config) + self.save_model(model) + self.save_metrics(train_losses, test_losses, accuracies) + + def save_model(self, model): + torch.save(model.state_dict(), self.dir_path('net.pt')) def dir_path(self, file): return '{}/{}'.format(self.dirname, file) + def save_metrics(self, train_losses, test_losses, accuracies): + data_frame = pd.DataFrame({ + 'train_loss': train_losses, + 'test_loss': test_losses, + 'accuracy': accuracies + }) + data_frame.to_csv(self.dir_path('metrics.csv')) + def make_dir(self, prefix): time_string = time.strftime('%Y%m%d%H%M%S') prefix = '' if prefix == '' else '{}-'.format(prefix) -- cgit v1.2.3