m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/experiment/experiment.py
blob: 600732f94e7f9b95e0a0874ac1fd0551a423982e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import shutil
import subprocess
import time

from train.train import train_model
from util.parse_config import parse_file
from model.encoder import EncoderModel

class Experiment:
    """A single training run with its own timestamped output directory.

    On construction, creates ``outputs/[<prefix>-]<YYYYmmddHHMMSS>`` and
    copies the config file into it as ``config.yaml``; ``run`` then builds
    and trains the model described by that config.
    """

    def __init__(self, file, prefix, device):
        """Create the output directory and snapshot the config into it.

        Args:
            file: Path to the experiment config file, later read by
                ``parse_file`` in ``run``.
            prefix: Optional label prepended to the output directory name;
                an empty string (or ``None``) means no prefix.
            device: Device handed to the model constructor, ``.to()`` and
                ``train_model``; presumably a torch device or device string
                (NOTE(review): confirm with callers).

        Raises:
            FileExistsError: if the output directory already exists, e.g. two
                runs with the same prefix started within the same second.
            OSError: if the config file cannot be copied.
        """
        self.file = file
        self.device = device
        self.make_dir(prefix)
        self.copy_config(file)

    def run(self):
        """Build the model from the config file and train it.

        ``parse_file`` returns model and training keyword arguments; the
        model is constructed with ``self.device`` and moved onto it before
        being passed to ``train_model``.
        """
        model_config, train_config = parse_file(self.file)
        model = EncoderModel(device=self.device, **model_config).to(self.device)
        train_model(model, device=self.device, **train_config)

    def dir_path(self, file):
        """Return the path of ``file`` inside this run's output directory."""
        return os.path.join(self.dirname, file)

    def make_dir(self, prefix):
        """Create ``outputs/[<prefix>-]<timestamp>`` and store it as ``self.dirname``.

        The timestamp is local time at one-second resolution.
        """
        time_string = time.strftime('%Y%m%d%H%M%S')
        # Truthiness (not == '') so that None also means "no prefix" instead
        # of producing a literal "None-" directory name.
        name = '{}-{}'.format(prefix, time_string) if prefix else time_string
        self.dirname = os.path.join('outputs', name)
        # makedirs creates the 'outputs' parent on a fresh checkout; the
        # default exist_ok=False still refuses to reuse an existing run dir.
        os.makedirs(self.dirname)

    def copy_config(self, file):
        """Copy the config file into the output directory as ``config.yaml``.

        Raises ``OSError`` (e.g. ``FileNotFoundError``) if the copy fails,
        rather than silently continuing without a config snapshot.
        """
        shutil.copyfile(file, self.dir_path('config.yaml'))