class ConvConfig:
    """Hyper-parameters for one convolutional layer.

    ``conv_args()`` exposes the convolution settings under keyword names
    (``in_channels``, ``out_channels``, ``kernel_size``, ``stride``,
    ``padding``). These presumably feed a ``torch.nn.Conv2d``-style
    constructor; confirm against the caller. ``max_pool`` is stored but
    deliberately left out of ``conv_args()``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        max_pool: Pooling setting for this layer. It is not interpreted here.
            NOTE(review): presumably a flag or pooling spec consumed by the
            model builder; confirm.
        size: Kernel size, exported as ``kernel_size``.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, max_pool, size=3, stride=1, padding=1):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_pool = max_pool
        self.size = size
        self.stride = stride
        self.padding = padding

    def conv_args(self):
        """Return a fresh dict of convolution keyword arguments.

        ``self.size`` is renamed to ``kernel_size``. ``max_pool`` is excluded
        because it is not a convolution argument.
        """
        return {
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.size,
            'stride': self.stride,
            'padding': self.padding,
        }

    def __repr__(self):
        # Conventional ``ClassName(field=value, ...)`` form, in constructor
        # parameter order; type(self) keeps subclass names accurate.
        return (
            "{}(in_channels={!r}, out_channels={!r}, max_pool={!r}, "
            "size={!r}, stride={!r}, padding={!r})"
        ).format(
            type(self).__name__,
            self.in_channels,
            self.out_channels,
            self.max_pool,
            self.size,
            self.stride,
            self.padding,
        )


class LinearConfig:
    """Hyper-parameters for one fully connected layer.

    ``linear_args()`` exposes the settings under the keyword names
    ``in_features`` / ``out_features``. These presumably feed a
    ``torch.nn.Linear``-style constructor; confirm against the caller.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def linear_args(self):
        """Return a fresh dict of linear-layer keyword arguments."""
        return {
            'in_features': self.in_features,
            'out_features': self.out_features,
        }

    def __repr__(self):
        # Same conventional form as ConvConfig.__repr__.
        return "{}(in_features={!r}, out_features={!r})".format(
            type(self).__name__,
            self.in_features,
            self.out_features,
        )