From aa5b7986f94673874d9c66359ebbb223ded00767 Mon Sep 17 00:00:00 2001
From: Marcin Chrzanowski
Date: Sat, 1 May 2021 18:15:17 +0200
Subject: Add missing file

---
 src/net.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 src/net.py

diff --git a/src/net.py b/src/net.py
new file mode 100644
index 0000000..af4e984
--- /dev/null
+++ b/src/net.py
@@ -0,0 +1,57 @@
+import torch.nn as nn
+
+class Net(nn.Module):
+    def __init__(self, convolutions, linears, outputs, finalizer, batch_norm=False, dropout=False):
+        super(Net, self).__init__()
+
+        self.finalizer = finalizer
+
+        self.convolutions = self.make_convolutions(convolutions, batch_norm,
+                dropout)
+        self.linears = self.make_linears(linears, batch_norm, dropout)
+        self.final_linear = nn.Linear(
+            linears[-1].linear_args()['out_features'], outputs)
+
+    def forward(self, x):
+        x = x.unsqueeze(1)
+        x = self.convolutions(x)
+        x = x.view(-1, self.num_flat_features(x))
+        x = self.linears(x)
+        x = self.final_linear(x)
+        x = self.finalizer(x)
+        return x
+
+    def num_flat_features(self, x):
+        size = x.size()[1:]  # all dimensions except the batch dimension
+        num_features = 1
+        for s in size:
+            num_features *= s
+        return num_features
+
+    def make_convolutions(self, convolutions, batch_norm, dropout):
+        layers = []
+        for convolution in convolutions:
+            conv_args = convolution.conv_args()
+            layers.append(nn.Conv2d(**conv_args))
+            if batch_norm:
+                layers.append(nn.BatchNorm2d(conv_args['out_channels']))
+            layers.append(nn.ReLU())
+            if dropout:
+                layers.append(nn.Dropout2d(dropout))
+            if convolution.max_pool:
+                layers.append(nn.MaxPool2d(2))
+
+        return nn.Sequential(*layers)
+
+    def make_linears(self, linears, batch_norm, dropout):
+        layers = []
+        for linear in linears:
+            linear_args = linear.linear_args()
+            layers.append(nn.Linear(**linear_args))
+            if batch_norm:
+                layers.append(nn.BatchNorm1d(linear_args['out_features']))
+            layers.append(nn.ReLU())
+            if dropout:
+                layers.append(nn.Dropout(dropout))
+
+        return nn.Sequential(*layers)
--
cgit v1.2.3
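
A note on usage (not part of the patch): Net expects its convolutions and linears arguments to be sequences of spec objects exposing conv_args(), linear_args() and a max_pool attribute; those spec classes live elsewhere in the repository and are not added here. The sketch below is only a minimal illustration under that assumption, with hypothetical ConvSpec/LinearSpec stand-ins for whatever the rest of the code actually supplies.

# Usage sketch only. ConvSpec and LinearSpec are hypothetical stand-ins for
# the repository's real spec objects; Net only assumes they provide
# conv_args(), linear_args() and a max_pool attribute.
import torch
import torch.nn as nn

from net import Net  # assumes src/ is on the import path

class ConvSpec:
    def __init__(self, in_channels, out_channels, kernel_size, max_pool=False):
        self._args = dict(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, padding=kernel_size // 2)
        self.max_pool = max_pool

    def conv_args(self):
        return self._args

class LinearSpec:
    def __init__(self, in_features, out_features):
        self._args = dict(in_features=in_features, out_features=out_features)

    def linear_args(self):
        return self._args

# 28x28 single-channel inputs; two max-pooled conv blocks leave a 32x7x7 map,
# so the first linear layer takes 32 * 7 * 7 flattened features.
net = Net(
    convolutions=[ConvSpec(1, 16, 3, max_pool=True),
                  ConvSpec(16, 32, 3, max_pool=True)],
    linears=[LinearSpec(32 * 7 * 7, 128)],
    outputs=10,
    finalizer=nn.LogSoftmax(dim=1),
    batch_norm=True,
    dropout=0.25,
)

x = torch.randn(8, 28, 28)   # forward() unsqueezes the channel dim itself
print(net(x).shape)          # torch.Size([8, 10])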