m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- experiment/__init__.py 0
-rw-r--r-- experiment/__main__.py 17
-rw-r--r-- experiment/experiment.py 32
-rw-r--r-- util/parse_config.py 31
4 files changed, 80 insertions, 0 deletions
diff --git a/experiment/__init__.py b/experiment/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/experiment/__init__.py
diff --git a/experiment/__main__.py b/experiment/__main__.py
new file mode 100644
index 0000000..bf4de94
--- /dev/null
+++ b/experiment/__main__.py
@@ -0,0 +1,17 @@
+import sys
+
+import torch
+
+from experiment.experiment import Experiment
+
+# Command-line entry point: `python -m experiment CONFIG.yaml [PREFIX]`.
+# CONFIG.yaml is the experiment configuration; the optional PREFIX is handed
+# to Experiment together with the selected torch device.
+if len(sys.argv) < 2:
+    # sys.exit() with a string writes the message to stderr and exits with
+    # status 1; unlike the bare `exit` helper it does not depend on the
+    # `site` module having been loaded.
+    sys.exit('Provide YAML configuration file as argument')
+
+# Use the GPU when CUDA is available, otherwise fall back to the CPU.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+file = sys.argv[1]
+prefix = sys.argv[2] if len(sys.argv) > 2 else ''
+
+experiment = Experiment(file, prefix, device)
+experiment.run()
diff --git a/experiment/experiment.py b/experiment/experiment.py
new file mode 100644
index 0000000..600732f
--- /dev/null
+++ b/experiment/experiment.py
@@ -0,0 +1,32 @@
+import os
+import shutil
+import subprocess
+import time
+
+from train.train import train_model
+from util.parse_config import parse_file
+from model.encoder import EncoderModel
+
+class Experiment:
+    """A single training run with its own timestamped output directory.
+
+    Constructing an instance creates ``outputs/[<prefix>-]<YYYYmmddHHMMSS>``
+    and stores a copy of the configuration file there as ``config.yaml``;
+    ``run()`` then builds the model and trains it.
+    """
+
+    def __init__(self, file, prefix, device):
+        """Create the output directory and snapshot the configuration.
+
+        Args:
+            file: Path to the YAML configuration file.
+            prefix: Label prepended to the directory name; '' for none.
+            device: Device passed to the model and to ``train_model``.
+
+        Raises:
+            OSError: If the directory cannot be created or the configuration
+                file cannot be copied.
+        """
+        self.file = file
+        self.device = device
+        self.make_dir(prefix)
+        self.copy_config(file)
+
+    def run(self):
+        """Parse the configuration, build the encoder model and train it."""
+        model_config, train_config = parse_file(self.file)
+        model = EncoderModel(device=self.device, **model_config).to(self.device)
+        train_model(model, device=self.device, **train_config)
+
+    def dir_path(self, file):
+        """Return the path of ``file`` inside this experiment's directory."""
+        return '{}/{}'.format(self.dirname, file)
+
+    def make_dir(self, prefix):
+        """Create the run directory and record it as ``self.dirname``.
+
+        The name is the current local time formatted as ``%Y%m%d%H%M%S``,
+        preceded by ``<prefix>-`` when ``prefix`` is not the empty string.
+
+        Raises:
+            FileExistsError: If a directory with that name already exists.
+        """
+        time_string = time.strftime('%Y%m%d%H%M%S')
+        prefix = '' if prefix == '' else '{}-'.format(prefix)
+        dirname = 'outputs/{}{}'.format(prefix, time_string)
+        self.dirname = dirname
+        # makedirs also creates the parent `outputs` directory on first use;
+        # without exist_ok it still refuses to reuse an existing run directory.
+        os.makedirs(dirname)
+
+    def copy_config(self, file):
+        """Copy the configuration file into the run directory as ``config.yaml``.
+
+        Raises:
+            OSError: If ``file`` cannot be read or the copy cannot be written.
+        """
+        # shutil.copy works on every platform and raises on failure, whereas
+        # spawning `cp` required that external program and ignored a non-zero
+        # exit status.
+        shutil.copy(file, self.dir_path('config.yaml'))
diff --git a/util/parse_config.py b/util/parse_config.py
new file mode 100644
index 0000000..00a4fd4
--- /dev/null
+++ b/util/parse_config.py
@@ -0,0 +1,31 @@
+import yaml
+
+def parse_file(file):
+    """Load a YAML experiment configuration and split it into sections.
+
+    Args:
+        file: Path to the YAML configuration file.
+
+    Returns:
+        The ``(model_config, train_config)`` tuple produced by ``parse_yaml``.
+    """
+    # A distinct handle name keeps the `file` path argument intact, and the
+    # explicit encoding stops the result depending on the platform locale.
+    with open(file, encoding='utf-8') as stream:
+        return parse_yaml(yaml.safe_load(stream))
+
+def parse_model_config(yaml):
+    """Build the model keyword arguments from the parsed configuration.
+
+    ``hidden_dim``, ``d_ff``, ``n_layers`` and ``num_heads`` are required and
+    copied unchanged. ``input_dim`` is taken from ``n_tokens`` and
+    ``output_dim`` is ``max_count + 1``. ``use_positional`` and
+    ``use_feedforward`` are copied only when they are present.
+
+    Args:
+        yaml: Mapping holding the parsed configuration values.
+
+    Returns:
+        A new dict with the model settings.
+    """
+    config = {}
+    for name in ('hidden_dim', 'd_ff', 'n_layers', 'num_heads'):
+        config[name] = yaml[name]
+
+    # These two settings are stored under different names in the file.
+    config.update(
+        input_dim=yaml['n_tokens'],
+        output_dim=yaml['max_count'] + 1,
+    )
+
+    # Optional switches are forwarded only if the file sets them.
+    for flag in ('use_positional', 'use_feedforward'):
+        if flag in yaml:
+            config[flag] = yaml[flag]
+
+    return config
+
+def parse_train_config(yaml):
+    """Select the training settings from the parsed configuration.
+
+    ``lr``, ``num_steps``, ``batch_size``, ``n_tokens``, ``seqlen`` and
+    ``max_count`` are all required and are copied unchanged.
+
+    Args:
+        yaml: Mapping holding the parsed configuration values.
+
+    Returns:
+        A new dict containing exactly those six entries.
+    """
+    wanted = (
+        'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count',
+    )
+    config = {}
+    for name in wanted:
+        config[name] = yaml[name]
+    return config
+
+def parse_yaml(yaml):
+    """Split a parsed configuration into model and training settings.
+
+    Args:
+        yaml: Mapping holding the parsed configuration values.
+
+    Returns:
+        A ``(model_config, train_config)`` tuple built by
+        ``parse_model_config`` and ``parse_train_config``, in that order.
+    """
+    return parse_model_config(yaml), parse_train_config(yaml)