m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/config.py
blob: f95efc7b0e9610115cf680374fab824b60da9c7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class ConvConfig:
    """Hyperparameters describing a single convolutional layer.

    ``conv_args`` exposes the convolution settings as a keyword-argument
    dict. The key names match ``torch.nn.Conv2d``'s parameters, so callers
    presumably splat it into such a constructor (confirm against callers).
    ``max_pool`` is stored alongside the other settings but is deliberately
    not part of ``conv_args``. How it is interpreted is left to the caller.
    """

    def __init__(self, in_channels, out_channels, max_pool, size=3, stride=1, padding=1):
        # Attributes are bound in the same order the constructor lists them.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.max_pool = max_pool
        self.size, self.stride, self.padding = size, stride, padding

    def conv_args(self):
        """Return the convolution settings as a keyword-argument dict.

        The stored ``size`` is exposed under the key ``kernel_size``.
        ``max_pool`` is not included.
        """
        return dict(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.size,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        """Render one ``name: value`` line per setting, each ending in a newline."""
        fields = (
            ('in_channels', self.in_channels),
            ('out_channels', self.out_channels),
            ('size', self.size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('max_pool', self.max_pool),
        )
        return ''.join('{}: {}\n'.format(name, value) for name, value in fields)

class LinearConfig:
    """Hyperparameters describing a single fully-connected layer.

    ``linear_args`` exposes the settings as a keyword-argument dict. The key
    names match ``torch.nn.Linear``'s parameters, so callers presumably splat
    it into such a constructor (confirm against callers).
    """

    def __init__(self, in_features, out_features):
        self.in_features, self.out_features = in_features, out_features

    def linear_args(self):
        """Return ``in_features`` and ``out_features`` as a keyword-argument dict."""
        return dict(
            in_features=self.in_features,
            out_features=self.out_features,
        )

    def __repr__(self):
        """Render one ``name: value`` line per setting, each ending in a newline."""
        fields = (
            ('in_features', self.in_features),
            ('out_features', self.out_features),
        )
        return ''.join('{}: {}\n'.format(name, value) for name, value in fields)