diff options
-rw-r--r-- | src/config.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..f95efc7 --- /dev/null +++ b/src/config.py @@ -0,0 +1,43 @@ +class ConvConfig: + def __init__(self, in_channels, out_channels, max_pool, size=3, stride=1, padding=1): + self.in_channels = in_channels + self.out_channels = out_channels + self.max_pool = max_pool + self.size = size + self.stride = stride + self.padding = padding + + def conv_args(self): + return { + 'in_channels': self.in_channels, + 'out_channels': self.out_channels, + 'kernel_size': self.size, + 'stride': self.stride, + 'padding': self.padding, + } + + def __repr__(self): + return """in_channels: {} +out_channels: {} +size: {} +stride: {} +padding: {} +max_pool: {} +""".format(self.in_channels, self.out_channels, self.size, self.stride, + self.padding, self.max_pool) + +class LinearConfig: + def __init__(self, in_features, out_features): + self.in_features = in_features + self.out_features = out_features + + def linear_args(self): + return { + 'in_features': self.in_features, + 'out_features': self.out_features + } + + def __repr__(self): + return """in_features: {} +out_features: {} +""".format(self.in_features, self.out_features) |