import torch.nn as nn


class Net(nn.Module):
    """Configurable CNN: conv stack -> flatten -> linear stack -> final linear -> finalizer.

    Args:
        convolutions: sequence of conv layer specs. Each spec must provide
            ``conv_args()`` returning keyword arguments for ``nn.Conv2d``
            (including ``out_channels``) and a ``max_pool`` attribute; when
            ``max_pool`` is truthy a ``nn.MaxPool2d(2)`` follows that layer.
        linears: non-empty sequence of linear layer specs. Each spec must
            provide ``linear_args()`` returning keyword arguments for
            ``nn.Linear`` (including ``out_features``). The width of the last
            spec's output sizes the input of the final layer.
        outputs: number of output features of the final ``nn.Linear``.
        finalizer: callable applied to the final linear output.
        batch_norm: if truthy, insert BatchNorm after every conv/linear layer.
        dropout: ``False``/``0`` disables dropout; a float is used as the
            drop probability; ``True`` selects PyTorch's default of 0.5.

    Raises:
        ValueError: if ``linears`` is empty.
    """

    # Same default probability that nn.Dropout / nn.Dropout2d use.
    _DEFAULT_DROPOUT = 0.5

    def __init__(self, convolutions, linears, outputs, finalizer,
                 batch_norm=False, dropout=False):
        super().__init__()
        if not linears:
            # The final layer's in_features is read from the last linear spec.
            raise ValueError(
                "Net requires at least one linear spec to size final_linear")
        # Attribute assignment order determines module registration order
        # (and therefore state_dict key order); keep it stable.
        self.finalizer = finalizer
        self.convolutions = self.make_convolutions(convolutions, batch_norm, dropout)
        self.linears = self.make_linears(linears, batch_norm, dropout)
        self.final_linear = nn.Linear(
            linears[-1].linear_args()['out_features'], outputs)

    def forward(self, x):
        """Run the network on ``x`` and return the finalizer's output.

        A channel dimension is inserted at dim 1 before the conv stack, so
        ``x`` presumably arrives as (batch, height, width) -- TODO confirm.
        """
        x = x.unsqueeze(1)
        x = self.convolutions(x)
        # flatten(1) keeps dim 0 (batch) and collapses the rest; unlike
        # view() it also works when the conv output is non-contiguous.
        x = x.flatten(1)
        x = self.linears(x)
        x = self.final_linear(x)
        return self.finalizer(x)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except dim 0."""
        return x.size()[1:].numel()

    @classmethod
    def _dropout_probability(cls, dropout):
        """Translate the ``dropout`` argument into a drop probability.

        ``True`` would otherwise reach nn.Dropout as p=1.0 and zero every
        activation in training mode, so it is mapped to the default
        probability. Any other value is returned unchanged; falsy values
        disable dropout at the call sites.
        """
        if dropout is True:
            return cls._DEFAULT_DROPOUT
        return dropout

    def make_convolutions(self, convolutions, batch_norm, dropout):
        """Build the conv stack.

        Per spec: Conv2d [-> BatchNorm2d] -> ReLU [-> Dropout2d] [-> MaxPool2d(2)].
        """
        p = self._dropout_probability(dropout)
        layers = []
        for convolution in convolutions:
            conv_args = convolution.conv_args()
            layers.append(nn.Conv2d(**conv_args))
            if batch_norm:
                layers.append(nn.BatchNorm2d(conv_args['out_channels']))
            layers.append(nn.ReLU())
            if p:
                layers.append(nn.Dropout2d(p))
            if convolution.max_pool:
                layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def make_linears(self, linears, batch_norm, dropout):
        """Build the fully-connected stack.

        Per spec: Linear [-> BatchNorm1d] -> ReLU [-> Dropout].
        """
        p = self._dropout_probability(dropout)
        layers = []
        for linear in linears:
            linear_args = linear.linear_args()
            layers.append(nn.Linear(**linear_args))
            if batch_norm:
                layers.append(nn.BatchNorm1d(linear_args['out_features']))
            layers.append(nn.ReLU())
            if p:
                layers.append(nn.Dropout(p))
        return nn.Sequential(*layers)