From 3602753148198a1e03cbf681da5fbd1d6a3512e5 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sat, 1 May 2021 15:00:46 +0200 Subject: Remove rigid network --- src/classification.py | 32 +++----------------------------- 1 file changed, 3 insertions(+), 29 deletions(-) (limited to 'src') diff --git a/src/classification.py b/src/classification.py index 110a8d9..7a93e7f 100644 --- a/src/classification.py +++ b/src/classification.py @@ -25,33 +25,7 @@ def count_correct(output, target): correct += 1 return correct -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 6, 3, padding=1) - self.conv2 = nn.Conv2d(6, 16, 3, padding=1) - self.conv3 = nn.Conv2d(16, 32, 3, padding=1) - self.fc1 = nn.Linear(288, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 6) +def finalizer(x): + return torch.sigmoid(x) - def forward(self, x): - x = x.unsqueeze(1) - x = self.conv1(x) - x = F.relu(x) - x = F.max_pool2d(x, (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), 2) - x = F.max_pool2d(F.relu(self.conv3(x)), 2) - x = x.view(-1, self.num_flat_features(x)) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - x = torch.sigmoid(x) - return x - - def num_flat_features(self, x): - size = x.size()[1:] # all dimensions except the batch dimension - num_features = 1 - for s in size: - num_features *= s - return num_features +outputs = 6 -- cgit v1.2.3