diff options
Diffstat (limited to 'src/parse_config.py')
-rw-r--r-- | src/parse_config.py | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/src/parse_config.py b/src/parse_config.py new file mode 100644 index 0000000..91b4a41 --- /dev/null +++ b/src/parse_config.py @@ -0,0 +1,104 @@ +import yaml +import torch + +from config import ConvConfig, LinearConfig +import net_types + +def get(dictionary, key): + return dictionary.get(key) + +def parse_file(file): + with open(file) as file: + return parse(yaml.safe_load(file)) + +def parse_convolution(config): + max_pool = get(config, 'max_pool') or False + size = get(config, 'size') or 3 + stride = get(config, 'stride') or 1 + padding = get(config, 'padding') or 1 + return ConvConfig( + config['in_channels'], + config['out_channels'], + max_pool, + size, + stride, + padding + ) + +def parse_convolutions(config): + convolutions = [] + for convolution_config in config: + convolutions.append(parse_convolution(convolution_config)) + return convolutions + +def parse_linear(config): + return LinearConfig(config['in_features'], config['out_features']) + +def parse_linears(config): + linears = [] + for linear_config in config: + linears.append(parse_linear(linear_config)) + return linears + +# TODO: temporary placeholder +def foo(x): + return x + +def parse_type(typ): + if typ == 'classification': + return 6, torch.sigmoid + elif typ == 'counting-small': + return 60, foo + else: + raise Exception('unknown network type: {}'.format(typ)) + +def parse_net(config): + convolutions = parse_convolutions(config['convolutions']) + linears = parse_linears(config['linears']) + batch_norm = get(config, 'batch_norm') or False + dropout = get(config, 'dropout') or False + return { + 'convolutions': convolutions, + 'linears': linears, + 'batch_norm': batch_norm, + 'dropout': dropout + } + +def parse_type(typ): + net_type = net_types.types[typ] + return ( + net_type.outputs, + net_type.finalizer, + net_type.target_transform, + net_type.loss_function, + net_type.count_correct + ) + +def parse_augmentations(config): + augmentations = [] + for augmentation_config in config: + rotation = get(augmentation_config, 'rotation') or 0 + vflip = get(augmentation_config, 'vflip') or False + hflip = get(augmentation_config, 'hflip') or False + augmentations.append((rotation, vflip, hflip)) + return augmentations + +def parse(config): + net = parse_net(config) + outputs, finalizer, target_transform, loss_function, count_correct = parse_type(config['type']) + net['outputs'] = outputs + net['finalizer'] = finalizer + lr = config['lr'] + epochs = config['epochs'] + batch_size = config['batch_size'] + augmentations = parse_augmentations(get(config, 'augmentations') or []) + return ( + net, + lr, + epochs, + batch_size, + augmentations, + target_transform, + loss_function, + count_correct + ) |