1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
|
import yaml
import torch
from config import ConvConfig, LinearConfig
import net_types
def get(dictionary, key, default=None):
    """Return ``dictionary[key]``, or *default* when the key is missing or null.

    Only a missing key or a stored ``None`` (e.g. an empty ``key:`` in YAML)
    falls back to *default*. A stored falsy value such as ``0`` or ``False``
    is returned as-is, unlike the ``get(...) or fallback`` idiom. With the
    default ``default=None`` this behaves exactly like ``dictionary.get(key)``.
    """
    value = dictionary.get(key)
    return default if value is None else value
def parse_file(file):
    """Load the YAML document at path *file* and return ``parse`` of it."""
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding. Use a distinct name for the handle
    # instead of shadowing the path parameter.
    with open(file, encoding='utf-8') as handle:
        return parse(yaml.safe_load(handle))
def parse_convolution(config):
    """Build a ConvConfig from one entry of the ``convolutions`` list.

    Required keys: ``in_channels`` and ``out_channels``.
    Optional keys, with the default used when the key is missing or null:
    ``max_pool`` (False), ``size`` (3), ``stride`` (1), ``padding`` (1).

    Raises:
        KeyError: if a required key is missing.
    """
    def option(key, default):
        # Use an explicit None check rather than ``value or default``.
        # The ``or`` form silently replaced a configured ``padding: 0``
        # with 1, and a bad ``size``/``stride`` of 0 with a valid-looking
        # default.
        value = config.get(key)
        return default if value is None else value

    return ConvConfig(
        config['in_channels'],
        config['out_channels'],
        option('max_pool', False),
        option('size', 3),
        option('stride', 1),
        option('padding', 1),
    )
def parse_convolutions(config):
    """Parse each entry of the ``convolutions`` list with ``parse_convolution``.

    Returns the resulting ConvConfig objects in the same order as *config*.
    """
    return [parse_convolution(entry) for entry in config]
def parse_linear(config):
    """Build a LinearConfig from one entry of the ``linears`` list.

    Both ``in_features`` and ``out_features`` are required; a missing key
    raises KeyError.
    """
    in_features = config['in_features']
    out_features = config['out_features']
    return LinearConfig(in_features, out_features)
def parse_linears(config):
    """Parse each entry of the ``linears`` list with ``parse_linear``.

    Returns the resulting LinearConfig objects in the same order as *config*.
    """
    return [parse_linear(entry) for entry in config]
# TODO: temporary placeholder (identity function).
def foo(x):
    """Hand back *x* untouched; the very same object is returned."""
    return x
# Network types are resolved by ``parse_type`` further down in this module,
# which looks them up in ``net_types.types``.
def parse_net(config):
    """Collect the network-architecture settings from *config* into a dict.

    Keys of the returned dict:
      * ``convolutions``: list parsed from the required ``convolutions`` entry;
      * ``linears``: list parsed from the required ``linears`` entry;
      * ``batch_norm`` / ``dropout``: the configured value, or False when
        missing or falsy.
    """
    # Dict-literal values are evaluated top to bottom, so convolutions are
    # parsed before linears, and a missing 'convolutions' key is reported first.
    return {
        'convolutions': parse_convolutions(config['convolutions']),
        'linears': parse_linears(config['linears']),
        'batch_norm': get(config, 'batch_norm') or False,
        'dropout': get(config, 'dropout') or False,
    }
def parse_type(typ):
    """Look up network type *typ* in ``net_types.types`` and unpack it.

    Returns the 5-tuple ``(outputs, finalizer, target_transform,
    loss_function, count_correct)``, read in that order from the registered
    entry's attributes.

    Raises:
        KeyError: if *typ* is not a registered network type. The message
            names the unknown type instead of echoing only the bare key.
    """
    try:
        net_type = net_types.types[typ]
    except KeyError:
        # Stay a KeyError so existing callers still catch it, but make the
        # message self-explanatory.
        raise KeyError('unknown network type: {!r}'.format(typ)) from None
    return (
        net_type.outputs,
        net_type.finalizer,
        net_type.target_transform,
        net_type.loss_function,
        net_type.count_correct,
    )
def parse_augmentations(config):
    """Turn a list of augmentation specs into ``(rotation, vflip, hflip)`` tuples.

    Each entry may set ``rotation`` (default 0) and ``vflip`` / ``hflip``
    (default False). A missing, null, or otherwise falsy value falls back
    to the default. Output order matches *config*.
    """
    return [
        (
            entry.get('rotation') or 0,
            entry.get('vflip') or False,
            entry.get('hflip') or False,
        )
        for entry in config
    ]
def parse(config):
    """Parse a loaded configuration mapping into everything training needs.

    Returns the 8-tuple ``(net, lr, epochs, batch_size, augmentations,
    target_transform, loss_function, count_correct)``. ``net`` is the dict
    from ``parse_net``, extended with the ``outputs`` and ``finalizer`` of the
    network type named by ``config['type']``. ``lr``, ``epochs`` and
    ``batch_size`` are required keys. ``augmentations`` is optional and
    defaults to an empty list when missing or falsy.
    """
    net = parse_net(config)
    (outputs, finalizer,
     target_transform, loss_function, count_correct) = parse_type(config['type'])
    net.update(outputs=outputs, finalizer=finalizer)

    lr, epochs, batch_size = config['lr'], config['epochs'], config['batch_size']
    augmentations = parse_augmentations(get(config, 'augmentations') or [])

    return (
        net, lr, epochs, batch_size, augmentations,
        target_transform, loss_function, count_correct,
    )
|