m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-04-29 11:33:54 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-04-29 11:33:54 +0200
commit8c5636428a5674c7743998ef8edb9dc51ce8c1c9 (patch)
tree83f64bdbc0cdb166e990e01c499b5414c544e495
parent9e22e1e8a221ce55789438956055bb02b3069162 (diff)
Add classification net
-rw-r--r--src/classification.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/src/classification.py b/src/classification.py
new file mode 100644
index 0000000..110a8d9
--- /dev/null
+++ b/src/classification.py
@@ -0,0 +1,57 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+def target_transform(labels):
+ return (labels > 0) + 0.
+
+def twoargmax(a):
+ l = list(zip(a, range(len(a))))
+ l.sort()
+ return [x[1] for x in l[-2:]]
+
+def loss_function(output, target):
+ return F.binary_cross_entropy(output, target)
+
+def count_correct(output, target):
+ correct = 0
+ for i in range(len(output)):
+ selected = twoargmax(output[i])
+ both_correct = True
+ for selection in selected:
+ if target[i][selection] != 1:
+ both_correct = False
+ if both_correct:
+ correct += 1
+ return correct
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
+ self.conv2 = nn.Conv2d(6, 16, 3, padding=1)
+ self.conv3 = nn.Conv2d(16, 32, 3, padding=1)
+ self.fc1 = nn.Linear(288, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 6)
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, (2, 2))
+ x = F.max_pool2d(F.relu(self.conv2(x)), 2)
+ x = F.max_pool2d(F.relu(self.conv3(x)), 2)
+ x = x.view(-1, self.num_flat_features(x))
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ x = torch.sigmoid(x)
+ return x
+
+ def num_flat_features(self, x):
+ size = x.size()[1:] # all dimensions except the batch dimension
+ num_features = 1
+ for s in size:
+ num_features *= s
+ return num_features