From 8c5636428a5674c7743998ef8edb9dc51ce8c1c9 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Thu, 29 Apr 2021 11:33:54 +0200 Subject: Add classification net --- src/classification.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/classification.py (limited to 'src') diff --git a/src/classification.py b/src/classification.py new file mode 100644 index 0000000..110a8d9 --- /dev/null +++ b/src/classification.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch + +def target_transform(labels): + return (labels > 0) + 0. + +def twoargmax(a): + l = list(zip(a, range(len(a)))) + l.sort() + return [x[1] for x in l[-2:]] + +def loss_function(output, target): + return F.binary_cross_entropy(output, target) + +def count_correct(output, target): + correct = 0 + for i in range(len(output)): + selected = twoargmax(output[i]) + both_correct = True + for selection in selected: + if target[i][selection] != 1: + both_correct = False + if both_correct: + correct += 1 + return correct + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 3, padding=1) + self.conv2 = nn.Conv2d(6, 16, 3, padding=1) + self.conv3 = nn.Conv2d(16, 32, 3, padding=1) + self.fc1 = nn.Linear(288, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 6) + + def forward(self, x): + x = x.unsqueeze(1) + x = self.conv1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = F.max_pool2d(F.relu(self.conv3(x)), 2) + x = x.view(-1, self.num_flat_features(x)) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + x = torch.sigmoid(x) + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features -- cgit v1.2.3