m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/parse_config.py
blob: 91b4a41e4f2c9efd7c72fe0f5a5615f33b1369ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import yaml
import torch

from config import ConvConfig, LinearConfig
import net_types

def get(dictionary, key):
    """Return the value stored under *key*, or ``None`` when it is absent.

    Thin wrapper over ``dict.get`` used for optional config entries.
    """
    found = dictionary.get(key, None)
    return found

def parse_file(file):
    """Load a YAML config file from disk and parse it.

    Args:
        file: Path of the YAML config file to read.

    Returns:
        The tuple produced by ``parse`` for the loaded mapping.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding. The handle gets its own name instead of rebinding the
    # ``file`` path parameter.
    with open(file, encoding='utf-8') as stream:
        # safe_load only constructs plain Python objects (no arbitrary tags).
        config = yaml.safe_load(stream)
    return parse(config)

def parse_convolution(config):
    """Build a ``ConvConfig`` from one entry of the ``convolutions`` list.

    Required keys: ``in_channels``, ``out_channels`` (a missing one raises
    ``KeyError``).
    Optional keys: ``max_pool`` (falsy or missing -> False), and ``size``,
    ``stride``, ``padding``, which default to 3, 1 and 1 respectively when
    the key is missing or null.
    """
    def option(key, default):
        # Fall back only when the key is absent or null. A plain
        # ``value or default`` would also replace legitimate falsy values,
        # e.g. turning an explicit ``padding: 0`` into ``padding: 1``.
        value = config.get(key)
        return default if value is None else value

    max_pool = config.get('max_pool') or False
    size = option('size', 3)
    stride = option('stride', 1)
    padding = option('padding', 1)
    return ConvConfig(
        config['in_channels'],
        config['out_channels'],
        max_pool,
        size,
        stride,
        padding
    )

def parse_convolutions(config):
    """Parse each entry of the ``convolutions`` list, preserving order."""
    return [parse_convolution(entry) for entry in config]

def parse_linear(config):
    """Build a ``LinearConfig`` from one entry of the ``linears`` list.

    Both ``in_features`` and ``out_features`` are required; a missing key
    raises ``KeyError``.
    """
    in_features = config['in_features']
    out_features = config['out_features']
    return LinearConfig(in_features, out_features)

def parse_linears(config):
    """Parse each entry of the ``linears`` list, preserving order."""
    return [parse_linear(entry) for entry in config]

# TODO: temporary placeholder; replace with a real finalizer.
def foo(x):
    """Identity function: hand back *x* untouched.

    Used as a stand-in finalizer by the hard-coded network-type table
    in this module.
    """
    unchanged = x
    return unchanged

def _legacy_parse_type(typ):
    """Superseded hard-coded lookup of ``(outputs, finalizer)`` by type name.

    This was a first definition of ``parse_type``. A second ``parse_type``
    (backed by ``net_types.types``) is defined later in this module and
    replaced it at import time, so this version was unreachable. It is kept
    under a private name for reference so that nothing is silently shadowed.
    TODO: delete if nothing needs it.

    Raises:
        ValueError: if *typ* is not one of the known type names.
    """
    if typ == 'classification':
        return 6, torch.sigmoid
    if typ == 'counting-small':
        return 60, foo
    raise ValueError('unknown network type: {}'.format(typ))

def parse_net(config):
    """Collect network-architecture settings from the top-level config.

    ``convolutions`` and ``linears`` are required (missing -> ``KeyError``).
    ``batch_norm`` and ``dropout`` are optional; a missing or falsy value
    becomes ``False``.

    Returns:
        dict with keys ``convolutions``, ``linears``, ``batch_norm`` and
        ``dropout``, in that order.
    """
    net = {}
    net['convolutions'] = parse_convolutions(config['convolutions'])
    net['linears'] = parse_linears(config['linears'])
    net['batch_norm'] = config.get('batch_norm') or False
    net['dropout'] = config.get('dropout') or False
    return net

def parse_type(typ):
    """Look up the registered network type named *typ*.

    Returns:
        Tuple ``(outputs, finalizer, target_transform, loss_function,
        count_correct)`` read from the matching ``net_types.types`` entry.

    Raises:
        KeyError: if *typ* is not registered. The message names the
            offending type.
    """
    # NOTE(review): assumes net_types.types is a mapping keyed by type name.
    try:
        net_type = net_types.types[typ]
    except KeyError:
        # Keep KeyError so existing callers still catch the same type, but
        # replace the bare key with a readable message.
        raise KeyError('unknown network type: {}'.format(typ)) from None
    return (
        net_type.outputs,
        net_type.finalizer,
        net_type.target_transform,
        net_type.loss_function,
        net_type.count_correct
    )

def parse_augmentations(config):
    """Convert the ``augmentations`` list into ``(rotation, vflip, hflip)`` tuples.

    Any field may be omitted: a missing or falsy ``rotation`` becomes 0, and
    missing or falsy ``vflip``/``hflip`` become False.
    """
    return [
        (
            entry.get('rotation') or 0,
            entry.get('vflip') or False,
            entry.get('hflip') or False,
        )
        for entry in config
    ]

def parse(config):
    """Turn a loaded config mapping into the training settings.

    Required keys: ``type``, ``lr``, ``epochs``, ``batch_size`` plus the keys
    ``parse_net`` needs. ``augmentations`` is optional (missing or falsy means
    no augmentations).

    Returns:
        Tuple ``(net, lr, epochs, batch_size, augmentations,
        target_transform, loss_function, count_correct)``, where ``net`` is
        the ``parse_net`` dict extended with ``outputs`` and ``finalizer``.
    """
    net = parse_net(config)
    (outputs, finalizer,
     target_transform, loss_function, count_correct) = parse_type(config['type'])
    net.update(outputs=outputs, finalizer=finalizer)

    # Looked up in this order so a missing key raises for the first absent one.
    lr, epochs, batch_size = (
        config[name] for name in ('lr', 'epochs', 'batch_size')
    )
    augmentations = parse_augmentations(config.get('augmentations') or [])

    return (
        net, lr, epochs, batch_size, augmentations,
        target_transform, loss_function, count_correct,
    )