m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/experiment/experiment.py
blob: b06408fec602d040b940eaa8b91327dc548a035d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import shutil
import subprocess
import time

import pandas as pd
import torch

from train.train import train_model
from util.parse_config import parse_file
from model.encoder import EncoderModel
from model.encoder import EncoderModel

class Experiment:
    """A single training run whose artifacts live in a fresh timestamped directory.

    Constructing an instance creates ``<OUTPUT_ROOT>/[<prefix>-]<YYYYmmddHHMMSS>``
    and copies the config file into it as ``config.yaml``.  :meth:`run` then
    builds and trains the model and writes ``net.pt`` (the model's
    ``state_dict``) and ``metrics.csv`` into the same directory.
    """

    # Parent directory for all run directories (relative to the current
    # working directory).  Override on a subclass/instance to redirect output.
    OUTPUT_ROOT = 'outputs'

    def __init__(self, file, prefix, device):
        """Create the run directory and snapshot the config into it.

        Args:
            file: path to the experiment config file; it is read by
                ``parse_file`` in :meth:`run`.
            prefix: optional label prepended to the run directory name;
                ``''`` or ``None`` means no prefix.
            device: passed to ``EncoderModel``, ``Module.to`` and
                ``train_model`` (presumably a torch device or a string such as
                ``'cpu'`` — TODO confirm).

        Raises:
            FileExistsError: if the run directory already exists.
            OSError: if the config file cannot be copied.
        """
        self.file = file
        self.device = device
        self.make_dir(prefix)
        self.copy_config(file)

    def run(self):
        """Build the model from the config, train it, and save model + metrics."""
        model_config, train_config = parse_file(self.file)
        model = EncoderModel(device=self.device, **model_config).to(self.device)
        train_losses, test_losses, accuracies = train_model(
            model, device=self.device, **train_config)
        self.save_model(model)
        self.save_metrics(train_losses, test_losses, accuracies)

    def save_model(self, model):
        """Save ``model.state_dict()`` (weights only, not the module) to ``net.pt``."""
        torch.save(model.state_dict(), self.dir_path('net.pt'))

    def dir_path(self, file):
        """Return the path of *file* inside this run's output directory."""
        return os.path.join(self.dirname, file)

    def save_metrics(self, train_losses, test_losses, accuracies):
        """Write the training metrics to ``metrics.csv`` in the run directory.

        Each argument becomes one column (``train_loss``, ``test_loss``,
        ``accuracy``); row *i* holds the i-th entry of each sequence
        (presumably one row per epoch — TODO confirm against ``train_model``).
        The sequences must have equal length, otherwise pandas raises
        ``ValueError``.  The DataFrame index is written as the first column.
        """
        data_frame = pd.DataFrame({
            'train_loss': train_losses,
            'test_loss': test_losses,
            'accuracy': accuracies,
        })
        data_frame.to_csv(self.dir_path('metrics.csv'))

    def make_dir(self, prefix):
        """Create this run's output directory and store its path in ``self.dirname``.

        The directory is ``<OUTPUT_ROOT>/<prefix>-<timestamp>``, or
        ``<OUTPUT_ROOT>/<timestamp>`` when *prefix* is empty/None.  The
        timestamp is local time with one-second resolution.  ``OUTPUT_ROOT``
        itself is created if missing.

        Raises:
            FileExistsError: if the run directory already exists, e.g. two
                runs with the same prefix started in the same second.  This is
                deliberate: one run must never overwrite another's artifacts.
        """
        time_string = time.strftime('%Y%m%d%H%M%S')
        prefix = '{}-'.format(prefix) if prefix else ''
        dirname = os.path.join(self.OUTPUT_ROOT, prefix + time_string)
        # The parent may legitimately exist already; only the leaf must be new.
        os.makedirs(self.OUTPUT_ROOT, exist_ok=True)
        os.mkdir(dirname)
        self.dirname = dirname

    def copy_config(self, file):
        """Copy the config *file* into the run directory as ``config.yaml``.

        Unlike shelling out to ``cp``, a failed copy raises (``OSError`` /
        ``FileNotFoundError``) instead of being silently ignored, so a run
        directory never ends up without its config snapshot.
        """
        shutil.copyfile(file, self.dir_path('config.yaml'))