m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/experiment
diff options
context:
space:
mode:
Diffstat (limited to 'experiment')
-rw-r--r--experiment/experiment.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/experiment/experiment.py b/experiment/experiment.py
index 600732f..b06408f 100644
--- a/experiment/experiment.py
+++ b/experiment/experiment.py
@@ -2,6 +2,9 @@ import subprocess
import os
import time
+import pandas as pd
+import torch
+
from train.train import train_model
from util.parse_config import parse_file
from model.encoder import EncoderModel
@@ -16,11 +19,24 @@ class Experiment:
def run(self):
model_config, train_config = parse_file(self.file)
model = EncoderModel(device=self.device, **model_config).to(self.device)
- train_model(model, device=self.device, **train_config)
+ train_losses, test_losses, accuracies = train_model(model, device=self.device, **train_config)
+ self.save_model(model)
+ self.save_metrics(train_losses, test_losses, accuracies)
+
+ def save_model(self, model):
+ torch.save(model.state_dict(), self.dir_path('net.pt'))
def dir_path(self, file):
return '{}/{}'.format(self.dirname, file)
+ def save_metrics(self, train_losses, test_losses, accuracies):
+ data_frame = pd.DataFrame({
+ 'train_loss': train_losses,
+ 'test_loss': test_losses,
+ 'accuracy': accuracies
+ })
+ data_frame.to_csv(self.dir_path('metrics.csv'))
+
def make_dir(self, prefix):
time_string = time.strftime('%Y%m%d%H%M%S')
prefix = '' if prefix == '' else '{}-'.format(prefix)