1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
class ConvConfig:
    """Plain container for the settings of one convolution layer.

    The constructor arguments are stored unchanged as same-named attributes.
    ``max_pool`` is kept on the instance and shown by ``__repr__`` but is
    deliberately absent from ``conv_args()``.
    """

    def __init__(self, in_channels, out_channels, max_pool, size=3, stride=1, padding=1):
        """Store the layer settings.

        Args:
            in_channels: Value reported under ``'in_channels'``.
            out_channels: Value reported under ``'out_channels'``.
            max_pool: Pooling setting; stored as-is.
                NOTE(review): type/meaning (flag vs. pool size) is not
                visible here -- confirm against the caller.
            size: Kernel size, reported under ``'kernel_size'``. Default 3.
            stride: Value reported under ``'stride'``. Default 1.
            padding: Value reported under ``'padding'``. Default 1.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_pool = max_pool
        self.size = size
        self.stride = stride
        self.padding = padding

    def conv_args(self):
        """Return the convolution settings as a fresh keyword-argument dict.

        ``size`` is exposed under the key ``'kernel_size'``; ``max_pool`` is
        not included. Presumably meant to be unpacked into a convolution
        layer constructor (e.g. ``Conv2d(**cfg.conv_args())``) -- confirm
        against the caller.
        """
        return {
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.size,
            'stride': self.stride,
            'padding': self.padding,
        }

    def __repr__(self):
        """Return a multi-line ``name: value`` summary, one field per line."""
        # Built from explicit newline-terminated pieces so the emitted text
        # does not depend on how this source file is indented.
        template = (
            "in_channels: {}\n"
            "out_channels: {}\n"
            "size: {}\n"
            "stride: {}\n"
            "padding: {}\n"
            "max_pool: {}\n"
        )
        return template.format(
            self.in_channels,
            self.out_channels,
            self.size,
            self.stride,
            self.padding,
            self.max_pool,
        )
class LinearConfig:
    """Plain container for the settings of one linear (fully connected) layer.

    Both constructor arguments are stored unchanged as same-named attributes.
    """

    def __init__(self, in_features, out_features):
        """Store the layer settings.

        Args:
            in_features: Value reported under ``'in_features'``.
            out_features: Value reported under ``'out_features'``.
        """
        self.in_features = in_features
        self.out_features = out_features

    def linear_args(self):
        """Return the settings as a fresh keyword-argument dict.

        Presumably meant to be unpacked into a linear layer constructor
        (e.g. ``Linear(**cfg.linear_args())``) -- confirm against the caller.
        """
        return {
            'in_features': self.in_features,
            'out_features': self.out_features,
        }

    def __repr__(self):
        """Return a two-line ``name: value`` summary of the settings."""
        # Explicit newline-terminated pieces keep the output independent of
        # this file's indentation.
        template = (
            "in_features: {}\n"
            "out_features: {}\n"
        )
        return template.format(self.in_features, self.out_features)
|