m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/util/parse_config.py
blob: 39b9c1348703c706032314f889902e88099be137 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import yaml

def parse_file(file):
    """Load a YAML config file and split it into model and training configs.

    Args:
        file: Path of the YAML config file.

    Returns:
        Tuple ``(model_config, train_config)`` as produced by ``parse_yaml``.

    Raises:
        TypeError: if the document is empty or its top level is not a mapping.
        KeyError: if a required config key is missing (raised by the parsers).
    """
    # Keep the path and the open handle under different names. Opening in
    # binary mode lets PyYAML detect the encoding (UTF-8/UTF-16 via BOM)
    # itself instead of relying on the platform's default text encoding.
    with open(file, 'rb') as handle:
        data = yaml.safe_load(handle)

    # safe_load yields None for an empty file, and a list/scalar for a
    # non-mapping document; report that clearly instead of failing with an
    # opaque indexing error inside the parsers. TypeError matches what the
    # original code raised in these cases.
    if not isinstance(data, dict):
        raise TypeError(
            f'config file {file!r} must contain a YAML mapping at the top '
            f'level, got {type(data).__name__}'
        )
    return parse_yaml(data)

def parse_model_config(yaml):
    """Build the model configuration dict from a loaded config mapping.

    Args:
        yaml: Mapping loaded from the config file. NOTE: this name shadows
            the ``yaml`` module inside this function; it is kept as-is so
            keyword callers keep working.

    Returns:
        dict containing ``hidden_dim``, ``d_ff``, ``n_layers`` and
        ``num_heads`` copied verbatim, ``input_dim`` taken from
        ``n_tokens``, ``output_dim`` set to ``max_count + 1``, plus whichever
        of ``use_positional``, ``use_feedforward`` and ``use_attention`` are
        present in the input.

    Raises:
        KeyError: if any required key (including ``n_tokens`` or
            ``max_count``) is missing.
    """
    required_verbatim_params = ('hidden_dim', 'd_ff', 'n_layers', 'num_heads')
    optional_verbatim_params = (
        'use_positional', 'use_feedforward', 'use_attention',
    )

    config = {key: yaml[key] for key in required_verbatim_params}

    config['input_dim'] = yaml['n_tokens']
    # Presumably one output per count value 0..max_count inclusive, hence
    # the +1 -- confirm against the model definition.
    config['output_dim'] = yaml['max_count'] + 1

    # Optional flags are copied only when present; absent flags are left out
    # of the config entirely rather than given a value here.
    config.update(
        (key, yaml[key]) for key in optional_verbatim_params if key in yaml
    )
    return config

def parse_train_config(yaml):
    """Extract the training hyperparameters from a loaded config mapping.

    Copies ``lr``, ``num_steps``, ``batch_size``, ``n_tokens``, ``seqlen`` and
    ``max_count`` unchanged, in that order. Any other keys are ignored.

    Raises:
        KeyError: for the first of those keys that is missing.
    """
    train_keys = (
        'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count',
    )
    train_config = {}
    for name in train_keys:
        train_config[name] = yaml[name]
    return train_config

def parse_yaml(yaml):
    """Split one loaded config mapping into ``(model_config, train_config)``.

    The model section is parsed first and the training section second; both
    read from the same input mapping.
    """
    return parse_model_config(yaml), parse_train_config(yaml)