m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 15:00:46 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 15:00:46 +0200
commit3602753148198a1e03cbf681da5fbd1d6a3512e5 (patch)
tree6a28630ba0259bf0a81e9d17c455f352672e920b
parentb813b8c336c1114032f0f2cad7c0a2f9a9c89669 (diff)
Remove rigid network
-rw-r--r--src/classification.py32
1 files changed, 3 insertions, 29 deletions
diff --git a/src/classification.py b/src/classification.py
index 110a8d9..7a93e7f 100644
--- a/src/classification.py
+++ b/src/classification.py
@@ -25,33 +25,7 @@ def count_correct(output, target):
correct += 1
return correct
-class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
- self.conv2 = nn.Conv2d(6, 16, 3, padding=1)
- self.conv3 = nn.Conv2d(16, 32, 3, padding=1)
- self.fc1 = nn.Linear(288, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 6)
+def finalizer(x):
+ return torch.sigmoid(x)
- def forward(self, x):
- x = x.unsqueeze(1)
- x = self.conv1(x)
- x = F.relu(x)
- x = F.max_pool2d(x, (2, 2))
- x = F.max_pool2d(F.relu(self.conv2(x)), 2)
- x = F.max_pool2d(F.relu(self.conv3(x)), 2)
- x = x.view(-1, self.num_flat_features(x))
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = self.fc3(x)
- x = torch.sigmoid(x)
- return x
-
- def num_flat_features(self, x):
- size = x.size()[1:] # all dimensions except the batch dimension
- num_features = 1
- for s in size:
- num_features *= s
- return num_features
+outputs = 6