diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/classification.py | 32 | 
1 files changed, 3 insertions, 29 deletions
| diff --git a/src/classification.py b/src/classification.py index 110a8d9..7a93e7f 100644 --- a/src/classification.py +++ b/src/classification.py @@ -25,33 +25,7 @@ def count_correct(output, target):              correct += 1      return correct -class Net(nn.Module): -    def __init__(self): -        super(Net, self).__init__() -        self.conv1 = nn.Conv2d(1, 6, 3, padding=1) -        self.conv2 = nn.Conv2d(6, 16, 3, padding=1) -        self.conv3 = nn.Conv2d(16, 32, 3, padding=1) -        self.fc1 = nn.Linear(288, 120) -        self.fc2 = nn.Linear(120, 84) -        self.fc3 = nn.Linear(84, 6) +def finalizer(x): +    return torch.sigmoid(x) -    def forward(self, x): -        x = x.unsqueeze(1) -        x = self.conv1(x) -        x = F.relu(x) -        x = F.max_pool2d(x, (2, 2)) -        x = F.max_pool2d(F.relu(self.conv2(x)), 2) -        x = F.max_pool2d(F.relu(self.conv3(x)), 2) -        x = x.view(-1, self.num_flat_features(x)) -        x = F.relu(self.fc1(x)) -        x = F.relu(self.fc2(x)) -        x = self.fc3(x) -        x = torch.sigmoid(x) -        return x - -    def num_flat_features(self, x): -        size = x.size()[1:]  # all dimensions except the batch dimension -        num_features = 1 -        for s in size: -            num_features *= s -        return num_features +outputs = 6 |