"""Load a YAML experiment file and split it into model and training configs."""

import yaml

# Keys copied unchanged from the parsed YAML mapping into the model config.
_MODEL_REQUIRED_KEYS = ('hidden_dim', 'd_ff', 'n_layers', 'num_heads')

# Keys copied into the model config only when present in the mapping.
_MODEL_OPTIONAL_KEYS = ('use_positional', 'use_feedforward')

# Keys copied unchanged from the parsed YAML mapping into the training config.
_TRAIN_REQUIRED_KEYS = (
    'lr', 'num_steps', 'batch_size', 'n_tokens', 'seqlen', 'max_count',
)


def parse_file(file):
    """Read the YAML file at ``file`` and return its parsed configuration.

    Args:
        file: Path of the YAML file to read (anything ``open`` accepts).

    Returns:
        A ``(model_config, train_config)`` tuple, as produced by
        ``parse_yaml``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required key is missing from the document.
    """
    # Decode explicitly as UTF-8 rather than with the platform locale, so the
    # same file parses identically on every machine. A separate name is used
    # for the handle so the ``file`` argument is not rebound mid-function.
    with open(file, encoding="utf-8") as stream:
        return parse_yaml(yaml.safe_load(stream))


def parse_model_config(yaml):
    """Build the model configuration from a parsed YAML mapping.

    NOTE: the parameter name shadows the ``yaml`` module inside this
    function; it is kept for backward compatibility with keyword callers.

    Args:
        yaml: Mapping holding at least ``hidden_dim``, ``d_ff``,
            ``n_layers``, ``num_heads``, ``n_tokens`` and ``max_count``.
            ``use_positional`` and ``use_feedforward`` are optional.

    Returns:
        A dict with the required keys copied verbatim, ``input_dim`` set to
        ``n_tokens``, ``output_dim`` set to ``max_count + 1``, and each
        optional key copied only if it is present in ``yaml``.

    Raises:
        KeyError: If a required key is missing.
    """
    config = {key: yaml[key] for key in _MODEL_REQUIRED_KEYS}
    config['input_dim'] = yaml['n_tokens']
    config['output_dim'] = yaml['max_count'] + 1
    # Optional switches are forwarded only when given, so their absence
    # leaves the key out of the result entirely.
    for key in _MODEL_OPTIONAL_KEYS:
        if key in yaml:
            config[key] = yaml[key]
    return config


def parse_train_config(yaml):
    """Build the training configuration from a parsed YAML mapping.

    Args:
        yaml: Mapping holding ``lr``, ``num_steps``, ``batch_size``,
            ``n_tokens``, ``seqlen`` and ``max_count``.

    Returns:
        A dict containing exactly those six keys, copied verbatim.

    Raises:
        KeyError: If any of the keys is missing.
    """
    return {key: yaml[key] for key in _TRAIN_REQUIRED_KEYS}


def parse_yaml(yaml):
    """Split a parsed YAML mapping into model and training configurations.

    Args:
        yaml: Mapping accepted by both ``parse_model_config`` and
            ``parse_train_config``.

    Returns:
        A ``(model_config, train_config)`` tuple of dicts.

    Raises:
        KeyError: If a key required by either configuration is missing.
    """
    model_config = parse_model_config(yaml)
    train_config = parse_train_config(yaml)
    return model_config, train_config