From 03c817412c638c905dea6c1a1967691f3ced57b8 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Thu, 27 May 2021 19:59:03 +0200 Subject: Add configurable run script --- util/parse_config.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 util/parse_config.py (limited to 'util') diff --git a/util/parse_config.py b/util/parse_config.py new file mode 100644 index 0000000..00a4fd4 --- /dev/null +++ b/util/parse_config.py @@ -0,0 +1,31 @@ +import yaml + +def parse_file(file): + with open(file) as file: + return parse_yaml(yaml.safe_load(file)) + +def parse_model_config(yaml): + required_verbatim_params = ['hidden_dim', 'd_ff', 'n_layers', 'num_heads'] + config = { key: yaml[key] for key in required_verbatim_params } + + config['input_dim'] = yaml['n_tokens'] + config['output_dim'] = yaml['max_count'] + 1 + + if 'use_positional' in yaml: + config['use_positional'] = yaml['use_positional'] + if 'use_feedforward' in yaml: + config['use_feedforward'] = yaml['use_feedforward'] + + return config + +def parse_train_config(yaml): + required_verbatim_params = [ + 'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count' + ] + config = { key: yaml[key] for key in required_verbatim_params } + return config + +def parse_yaml(yaml): + model_config = parse_model_config(yaml) + train_config = parse_train_config(yaml) + return model_config, train_config -- cgit v1.2.3