m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/net.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/src/net.py b/src/net.py
new file mode 100644
index 0000000..af4e984
--- /dev/null
+++ b/src/net.py
@@ -0,0 +1,57 @@
+import torch.nn as nn
+
+class Net(nn.Module):
+ def __init__(self, convolutions, linears, outputs, finalizer, batch_norm=False, dropout=False):
+ super(Net, self).__init__()
+
+ self.finalizer = finalizer
+
+ self.convolutions = self.make_convolutions(convolutions, batch_norm,
+ dropout)
+ self.linears = self.make_linears(linears, batch_norm, dropout)
+ self.final_linear = nn.Linear(
+ linears[-1].linear_args()['out_features'], outputs)
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ x = self.convolutions(x)
+ x = x.view(-1, self.num_flat_features(x))
+ x = self.linears(x)
+ x = self.final_linear(x)
+ x = self.finalizer(x)
+ return x
+
+ def num_flat_features(self, x):
+ size = x.size()[1:] # all dimensions except the batch dimension
+ num_features = 1
+ for s in size:
+ num_features *= s
+ return num_features
+
+ def make_convolutions(self, convolutions, batch_norm, dropout):
+ layers = []
+ for convolution in convolutions:
+ conv_args = convolution.conv_args()
+ layers.append(nn.Conv2d(**conv_args))
+ if batch_norm:
+ layers.append(nn.BatchNorm2d(conv_args['out_channels']))
+ layers.append(nn.ReLU())
+ if dropout:
+ layers.append(nn.Dropout2d(dropout))
+ if convolution.max_pool:
+ layers.append(nn.MaxPool2d(2))
+
+ return nn.Sequential(*layers)
+
+ def make_linears(self, linears, batch_norm, dropout):
+ layers = []
+ for linear in linears:
+ linear_args = linear.linear_args()
+ layers.append(nn.Linear(**linear_args))
+ if batch_norm:
+ layers.append(nn.BatchNorm1d(linear_args['out_features']))
+ layers.append(nn.ReLU())
+ if dropout:
+ layers.append(nn.Dropout(dropout))
+
+ return nn.Sequential(*layers)