diff options
Diffstat (limited to 'experiment/experiment.py')
-rw-r--r-- | experiment/experiment.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/experiment/experiment.py b/experiment/experiment.py new file mode 100644 index 0000000..600732f --- /dev/null +++ b/experiment/experiment.py @@ -0,0 +1,32 @@ +import subprocess +import os +import time + +from train.train import train_model +from util.parse_config import parse_file +from model.encoder import EncoderModel + +class Experiment: + def __init__(self, file, prefix, device): + self.file = file + self.device = device + self.make_dir(prefix) + self.copy_config(file) + + def run(self): + model_config, train_config = parse_file(self.file) + model = EncoderModel(device=self.device, **model_config).to(self.device) + train_model(model, device=self.device, **train_config) + + def dir_path(self, file): + return '{}/{}'.format(self.dirname, file) + + def make_dir(self, prefix): + time_string = time.strftime('%Y%m%d%H%M%S') + prefix = '' if prefix == '' else '{}-'.format(prefix) + dirname = 'outputs/{}{}'.format(prefix, time_string) + self.dirname = dirname + os.mkdir(dirname) + + def copy_config(self, file): + subprocess.run(['cp', file, '{}/config.yaml'.format(self.dirname)]) |