blob: 39b9c1348703c706032314f889902e88099be137 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import yaml
def parse_file(file):
    """Load a YAML config file and split it into model and training configs.

    Args:
        file: Path of the YAML file to read (anything ``open`` accepts).

    Returns:
        The ``(model_config, train_config)`` tuple produced by ``parse_yaml``.
    """
    # Use a distinct name for the handle instead of rebinding the ``file``
    # parameter, and decode explicitly as UTF-8 rather than relying on the
    # platform's locale-dependent default encoding.
    with open(file, encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    # NOTE(review): yaml.safe_load returns None for an empty document, which
    # parse_yaml will then fail on when subscripting — confirm whether a
    # clearer error is wanted here.
    return parse_yaml(loaded)
def parse_model_config(yaml):
    """Build the model keyword arguments from a parsed config mapping.

    Note: the parameter is named ``yaml`` (shadowing the module inside this
    function); it is kept for backward compatibility with keyword callers.

    Args:
        yaml: Parsed config mapping. It must contain ``hidden_dim``, ``d_ff``,
            ``n_layers``, ``num_heads``, ``n_tokens`` and ``max_count``. It may
            contain ``use_positional``, ``use_feedforward`` and
            ``use_attention``.

    Returns:
        A dict with the four required verbatim keys, plus:
        ``input_dim`` (from ``n_tokens``), ``output_dim`` (``max_count + 1``),
        and whichever optional ``use_*`` flags were present in ``yaml``.

    Raises:
        KeyError: if a required key is missing.
    """
    required_verbatim_params = ('hidden_dim', 'd_ff', 'n_layers', 'num_heads')
    # Flags copied through only when the config specifies them, so the model's
    # own defaults apply otherwise.
    optional_verbatim_params = ('use_positional', 'use_feedforward', 'use_attention')

    config = {key: yaml[key] for key in required_verbatim_params}
    config['input_dim'] = yaml['n_tokens']
    # +1 presumably so counts 0..max_count inclusive are all representable
    # as output classes — TODO confirm against the model definition.
    config['output_dim'] = yaml['max_count'] + 1
    config.update(
        {key: yaml[key] for key in optional_verbatim_params if key in yaml}
    )
    return config
def parse_train_config(yaml):
    """Extract the training hyper-parameters from a parsed config mapping.

    Each listed key is mandatory and copied through unchanged; any other
    entries in ``yaml`` are ignored.

    Raises:
        KeyError: if one of the training keys is missing.
    """
    train_keys = ('lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count')
    train_config = {}
    for name in train_keys:
        train_config[name] = yaml[name]
    return train_config
def parse_yaml(yaml):
    """Split one parsed config mapping into its model and training parts.

    Returns:
        A ``(model_config, train_config)`` tuple built by
        ``parse_model_config`` and ``parse_train_config`` respectively.
    """
    return parse_model_config(yaml), parse_train_config(yaml)
|