diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-04-29 11:33:54 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-04-29 11:33:54 +0200 |
commit | 8c5636428a5674c7743998ef8edb9dc51ce8c1c9 (patch) | |
tree | 83f64bdbc0cdb166e990e01c499b5414c544e495 | |
parent | 9e22e1e8a221ce55789438956055bb02b3069162 (diff) |
Add classification net
-rw-r--r-- | src/classification.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/src/classification.py b/src/classification.py new file mode 100644 index 0000000..110a8d9 --- /dev/null +++ b/src/classification.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch + +def target_transform(labels): + return (labels > 0) + 0. + +def twoargmax(a): + l = list(zip(a, range(len(a)))) + l.sort() + return [x[1] for x in l[-2:]] + +def loss_function(output, target): + return F.binary_cross_entropy(output, target) + +def count_correct(output, target): + correct = 0 + for i in range(len(output)): + selected = twoargmax(output[i]) + both_correct = True + for selection in selected: + if target[i][selection] != 1: + both_correct = False + if both_correct: + correct += 1 + return correct + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 3, padding=1) + self.conv2 = nn.Conv2d(6, 16, 3, padding=1) + self.conv3 = nn.Conv2d(16, 32, 3, padding=1) + self.fc1 = nn.Linear(288, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 6) + + def forward(self, x): + x = x.unsqueeze(1) + x = self.conv1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = F.max_pool2d(F.relu(self.conv3(x)), 2) + x = x.view(-1, self.num_flat_features(x)) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + x = torch.sigmoid(x) + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features |