diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:59:03 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:59:03 +0200 |
commit | 03c817412c638c905dea6c1a1967691f3ced57b8 (patch) | |
tree | fb60d8b7a1a94341aab1548f95a37a8d9f3c9112 | |
parent | 3e4924d0dcec2eba4f3019b55178cd1c3b70a474 (diff) |
Add configurable run script
-rw-r--r-- | experiment/__init__.py | 0 |
-rw-r--r-- | experiment/__main__.py | 17 |
-rw-r--r-- | experiment/experiment.py | 32 |
-rw-r--r-- | util/parse_config.py | 31 |
4 files changed, 80 insertions, 0 deletions
diff --git a/experiment/__init__.py b/experiment/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/experiment/__init__.py
diff --git a/experiment/__main__.py b/experiment/__main__.py
new file mode 100644
index 0000000..bf4de94
--- /dev/null
+++ b/experiment/__main__.py
@@ -0,0 +1,17 @@
+import sys
+
+import torch
+
+from experiment.experiment import Experiment
+
+file = None
+if len(sys.argv) < 2:
+    print('Provide YAML configuration file as argument')
+    exit(1)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+file = sys.argv[1]
+prefix = sys.argv[2] if len(sys.argv) > 2 else ''
+
+experiment = Experiment(file, prefix, device)
+experiment.run()
diff --git a/experiment/experiment.py b/experiment/experiment.py
new file mode 100644
index 0000000..600732f
--- /dev/null
+++ b/experiment/experiment.py
@@ -0,0 +1,32 @@
+import subprocess
+import os
+import time
+
+from train.train import train_model
+from util.parse_config import parse_file
+from model.encoder import EncoderModel
+
+class Experiment:
+    def __init__(self, file, prefix, device):
+        self.file = file
+        self.device = device
+        self.make_dir(prefix)
+        self.copy_config(file)
+
+    def run(self):
+        model_config, train_config = parse_file(self.file)
+        model = EncoderModel(device=self.device, **model_config).to(self.device)
+        train_model(model, device=self.device, **train_config)
+
+    def dir_path(self, file):
+        return '{}/{}'.format(self.dirname, file)
+
+    def make_dir(self, prefix):
+        time_string = time.strftime('%Y%m%d%H%M%S')
+        prefix = '' if prefix == '' else '{}-'.format(prefix)
+        dirname = 'outputs/{}{}'.format(prefix, time_string)
+        self.dirname = dirname
+        os.mkdir(dirname)
+
+    def copy_config(self, file):
+        subprocess.run(['cp', file, '{}/config.yaml'.format(self.dirname)])
diff --git a/util/parse_config.py b/util/parse_config.py
new file mode 100644
index 0000000..00a4fd4
--- /dev/null
+++ b/util/parse_config.py
@@ -0,0 +1,31 @@
+import yaml
+
+def parse_file(file):
+    with open(file) as file:
+        return parse_yaml(yaml.safe_load(file))
+
+def parse_model_config(yaml):
+    required_verbatim_params = ['hidden_dim', 'd_ff', 'n_layers', 'num_heads']
+    config = { key: yaml[key] for key in required_verbatim_params }
+
+    config['input_dim'] = yaml['n_tokens']
+    config['output_dim'] = yaml['max_count'] + 1
+
+    if 'use_positional' in yaml:
+        config['use_positional'] = yaml['use_positional']
+    if 'use_feedforward' in yaml:
+        config['use_feedforward'] = yaml['use_feedforward']
+
+    return config
+
+def parse_train_config(yaml):
+    required_verbatim_params = [
+        'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count'
+    ]
+    config = { key: yaml[key] for key in required_verbatim_params }
+    return config
+
+def parse_yaml(yaml):
+    model_config = parse_model_config(yaml)
+    train_config = parse_train_config(yaml)
+    return model_config, train_config