diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 18:15:17 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 18:15:17 +0200 |
commit | aa5b7986f94673874d9c66359ebbb223ded00767 (patch) | |
tree | d59f8a569ef77d248288aecdc12afcb4218f20d9 /src | |
parent | 290d55c4353a7374da14d67bc9ab3d33c236fa95 (diff) |
Add missing file
Diffstat (limited to 'src')
-rw-r--r-- | src/net.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/src/net.py b/src/net.py new file mode 100644 index 0000000..af4e984 --- /dev/null +++ b/src/net.py @@ -0,0 +1,57 @@ +import torch.nn as nn + +class Net(nn.Module): + def __init__(self, convolutions, linears, outputs, finalizer, batch_norm=False, dropout=False): + super(Net, self).__init__() + + self.finalizer = finalizer + + self.convolutions = self.make_convolutions(convolutions, batch_norm, + dropout) + self.linears = self.make_linears(linears, batch_norm, dropout) + self.final_linear = nn.Linear( + linears[-1].linear_args()['out_features'], outputs) + + def forward(self, x): + x = x.unsqueeze(1) + x = self.convolutions(x) + x = x.view(-1, self.num_flat_features(x)) + x = self.linears(x) + x = self.final_linear(x) + x = self.finalizer(x) + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features + + def make_convolutions(self, convolutions, batch_norm, dropout): + layers = [] + for convolution in convolutions: + conv_args = convolution.conv_args() + layers.append(nn.Conv2d(**conv_args)) + if batch_norm: + layers.append(nn.BatchNorm2d(conv_args['out_channels'])) + layers.append(nn.ReLU()) + if dropout: + layers.append(nn.Dropout2d(dropout)) + if convolution.max_pool: + layers.append(nn.MaxPool2d(2)) + + return nn.Sequential(*layers) + + def make_linears(self, linears, batch_norm, dropout): + layers = [] + for linear in linears: + linear_args = linear.linear_args() + layers.append(nn.Linear(**linear_args)) + if batch_norm: + layers.append(nn.BatchNorm1d(linear_args['out_features'])) + layers.append(nn.ReLU()) + if dropout: + layers.append(nn.Dropout(dropout)) + + return nn.Sequential(*layers) |