m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/util/parse_config.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/parse_config.py')
-rw-r--r--util/parse_config.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/util/parse_config.py b/util/parse_config.py
new file mode 100644
index 0000000..00a4fd4
--- /dev/null
+++ b/util/parse_config.py
@@ -0,0 +1,31 @@
+import yaml
+
+def parse_file(file):
+ with open(file) as file:
+ return parse_yaml(yaml.safe_load(file))
+
+def parse_model_config(yaml):
+ required_verbatim_params = ['hidden_dim', 'd_ff', 'n_layers', 'num_heads']
+ config = { key: yaml[key] for key in required_verbatim_params }
+
+ config['input_dim'] = yaml['n_tokens']
+ config['output_dim'] = yaml['max_count'] + 1
+
+ if 'use_positional' in yaml:
+ config['use_positional'] = yaml['use_positional']
+ if 'use_feedforward' in yaml:
+ config['use_feedforward'] = yaml['use_feedforward']
+
+ return config
+
+def parse_train_config(yaml):
+ required_verbatim_params = [
+ 'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count'
+ ]
+ config = { key: yaml[key] for key in required_verbatim_params }
+ return config
+
+def parse_yaml(yaml):
+ model_config = parse_model_config(yaml)
+ train_config = parse_train_config(yaml)
+ return model_config, train_config