m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/parse_config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/parse_config.py')
-rw-r--r--src/parse_config.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/src/parse_config.py b/src/parse_config.py
new file mode 100644
index 0000000..91b4a41
--- /dev/null
+++ b/src/parse_config.py
@@ -0,0 +1,104 @@
+import yaml
+import torch
+
+from config import ConvConfig, LinearConfig
+import net_types
+
+def get(dictionary, key):
+ return dictionary.get(key)
+
+def parse_file(file):
+ with open(file) as file:
+ return parse(yaml.safe_load(file))
+
+def parse_convolution(config):
+ max_pool = get(config, 'max_pool') or False
+ size = get(config, 'size') or 3
+ stride = get(config, 'stride') or 1
+ padding = get(config, 'padding') or 1
+ return ConvConfig(
+ config['in_channels'],
+ config['out_channels'],
+ max_pool,
+ size,
+ stride,
+ padding
+ )
+
+def parse_convolutions(config):
+ convolutions = []
+ for convolution_config in config:
+ convolutions.append(parse_convolution(convolution_config))
+ return convolutions
+
+def parse_linear(config):
+ return LinearConfig(config['in_features'], config['out_features'])
+
+def parse_linears(config):
+ linears = []
+ for linear_config in config:
+ linears.append(parse_linear(linear_config))
+ return linears
+
+# TODO: temporary placeholder
+def foo(x):
+ return x
+
+def parse_type(typ):
+ if typ == 'classification':
+ return 6, torch.sigmoid
+ elif typ == 'counting-small':
+ return 60, foo
+ else:
+ raise Exception('unknown network type: {}'.format(typ))
+
+def parse_net(config):
+ convolutions = parse_convolutions(config['convolutions'])
+ linears = parse_linears(config['linears'])
+ batch_norm = get(config, 'batch_norm') or False
+ dropout = get(config, 'dropout') or False
+ return {
+ 'convolutions': convolutions,
+ 'linears': linears,
+ 'batch_norm': batch_norm,
+ 'dropout': dropout
+ }
+
+def parse_type(typ):
+ net_type = net_types.types[typ]
+ return (
+ net_type.outputs,
+ net_type.finalizer,
+ net_type.target_transform,
+ net_type.loss_function,
+ net_type.count_correct
+ )
+
+def parse_augmentations(config):
+ augmentations = []
+ for augmentation_config in config:
+ rotation = get(augmentation_config, 'rotation') or 0
+ vflip = get(augmentation_config, 'vflip') or False
+ hflip = get(augmentation_config, 'hflip') or False
+ augmentations.append((rotation, vflip, hflip))
+ return augmentations
+
+def parse(config):
+ net = parse_net(config)
+ outputs, finalizer, target_transform, loss_function, count_correct = parse_type(config['type'])
+ net['outputs'] = outputs
+ net['finalizer'] = finalizer
+ lr = config['lr']
+ epochs = config['epochs']
+ batch_size = config['batch_size']
+ augmentations = parse_augmentations(get(config, 'augmentations') or [])
+ return (
+ net,
+ lr,
+ epochs,
+ batch_size,
+ augmentations,
+ target_transform,
+ loss_function,
+ count_correct
+ )