diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:00:46 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:00:46 +0200 |
commit | 3602753148198a1e03cbf681da5fbd1d6a3512e5 (patch) | |
tree | 6a28630ba0259bf0a81e9d17c455f352672e920b /src | |
parent | b813b8c336c1114032f0f2cad7c0a2f9a9c89669 (diff) |
Remove rigid network
Diffstat (limited to 'src')
-rw-r--r-- | src/classification.py | 32 |
1 files changed, 3 insertions, 29 deletions
diff --git a/src/classification.py b/src/classification.py index 110a8d9..7a93e7f 100644 --- a/src/classification.py +++ b/src/classification.py @@ -25,33 +25,7 @@ def count_correct(output, target): correct += 1 return correct -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 6, 3, padding=1) - self.conv2 = nn.Conv2d(6, 16, 3, padding=1) - self.conv3 = nn.Conv2d(16, 32, 3, padding=1) - self.fc1 = nn.Linear(288, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 6) +def finalizer(x): + return torch.sigmoid(x) - def forward(self, x): - x = x.unsqueeze(1) - x = self.conv1(x) - x = F.relu(x) - x = F.max_pool2d(x, (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), 2) - x = F.max_pool2d(F.relu(self.conv3(x)), 2) - x = x.view(-1, self.num_flat_features(x)) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - x = torch.sigmoid(x) - return x - - def num_flat_features(self, x): - size = x.size()[1:] # all dimensions except the batch dimension - num_features = 1 - for s in size: - num_features *= s - return num_features +outputs = 6 |