"""Parse a YAML experiment description into network and training settings."""

import yaml
import torch  # noqa: F401  (no direct use in this module; retained deliberately)

from config import ConvConfig, LinearConfig
import net_types


def get(dictionary, key):
    """Return ``dictionary[key]``, or ``None`` when the key is absent."""
    return dictionary.get(key)


def _get_or_default(dictionary, key, default):
    """Return ``dictionary[key]``, or ``default`` if it is missing or ``None``.

    Unlike ``get(...) or default`` this keeps explicit falsy values such as
    ``0``, so a setting can be deliberately switched off in the config file.
    """
    value = dictionary.get(key)
    return default if value is None else value


def parse_file(file):
    """Load the YAML document at path ``file`` and return ``parse()`` of it."""
    # Use a separate name for the handle so the path argument is not shadowed.
    with open(file) as stream:
        return parse(yaml.safe_load(stream))


def parse_convolution(config):
    """Build a ``ConvConfig`` from one convolution mapping.

    Required keys: ``in_channels`` and ``out_channels`` (``KeyError`` if
    absent). Optional keys and their defaults: ``max_pool`` (False),
    ``size`` (3), ``stride`` (1), ``padding`` (1).
    """
    max_pool = get(config, 'max_pool') or False
    size = get(config, 'size') or 3
    stride = get(config, 'stride') or 1
    # An explicit ``padding: 0`` must be honoured; only fall back to 1 when
    # the key is missing or null.
    padding = _get_or_default(config, 'padding', 1)
    return ConvConfig(
        config['in_channels'],
        config['out_channels'],
        max_pool,
        size,
        stride,
        padding,
    )


def parse_convolutions(config):
    """Return a list with one ``ConvConfig`` per entry of ``config``."""
    return [parse_convolution(entry) for entry in config]


def parse_linear(config):
    """Build a ``LinearConfig`` from ``in_features`` and ``out_features``.

    Both keys are required; a missing one raises ``KeyError``.
    """
    return LinearConfig(config['in_features'], config['out_features'])


def parse_linears(config):
    """Return a list with one ``LinearConfig`` per entry of ``config``."""
    return [parse_linear(entry) for entry in config]


# TODO: temporary placeholder
def foo(x):
    """Return ``x`` unchanged."""
    return x


def parse_net(config):
    """Extract the network description from the top-level config mapping.

    Required keys: ``convolutions`` and ``linears`` (``KeyError`` if absent).
    Optional keys ``batch_norm`` and ``dropout`` default to False.

    Returns a dict with keys ``convolutions``, ``linears``, ``batch_norm``
    and ``dropout``.
    """
    return {
        'convolutions': parse_convolutions(config['convolutions']),
        'linears': parse_linears(config['linears']),
        'batch_norm': get(config, 'batch_norm') or False,
        'dropout': get(config, 'dropout') or False,
    }


def parse_type(typ):
    """Look up ``typ`` in ``net_types.types`` and unpack its settings.

    Returns the tuple ``(outputs, finalizer, target_transform,
    loss_function, count_correct)`` read from the matching entry. A failed
    lookup propagates whatever ``net_types.types[typ]`` raises (a
    ``KeyError`` if ``types`` is a plain dict -- not visible from here).
    """
    net_type = net_types.types[typ]
    return (
        net_type.outputs,
        net_type.finalizer,
        net_type.target_transform,
        net_type.loss_function,
        net_type.count_correct,
    )


def parse_augmentations(config):
    """Return one ``(rotation, vflip, hflip)`` tuple per augmentation entry.

    Missing or falsy values become ``0`` for ``rotation`` and ``False`` for
    the two flip flags.
    """
    return [
        (
            get(entry, 'rotation') or 0,
            get(entry, 'vflip') or False,
            get(entry, 'hflip') or False,
        )
        for entry in config
    ]


def parse(config):
    """Parse the top-level config mapping.

    Required keys: ``convolutions``, ``linears``, ``type``, ``lr``,
    ``epochs`` and ``batch_size`` (``KeyError`` if absent). ``augmentations``
    is optional and defaults to an empty list.

    Returns the tuple ``(net, lr, epochs, batch_size, augmentations,
    target_transform, loss_function, count_correct)``, where ``net`` is the
    dict from ``parse_net`` extended with ``outputs`` and ``finalizer``.
    """
    net = parse_net(config)
    (
        outputs,
        finalizer,
        target_transform,
        loss_function,
        count_correct,
    ) = parse_type(config['type'])
    net['outputs'] = outputs
    net['finalizer'] = finalizer

    augmentations = parse_augmentations(get(config, 'augmentations') or [])

    return (
        net,
        config['lr'],
        config['epochs'],
        config['batch_size'],
        augmentations,
        target_transform,
        loss_function,
        count_correct,
    )