blob: 7a93e7fe079b8f362d8f0a0cc46158286c9a667d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
import torch.nn as nn
import torch.nn.functional as F
import torch
def target_transform(labels):
    """Binarize ``labels``: 1.0 where ``labels > 0``, 0.0 everywhere else.

    The comparison yields a boolean mask; adding the float ``0.0`` promotes it
    to a floating-point result (for torch tensors, the default float dtype),
    which is the form ``binary_cross_entropy`` expects for its target.

    Args:
        labels: array-like supporting element-wise ``> 0`` (presumably a torch
            tensor of raw label values — confirm with caller).

    Returns:
        Float array/tensor of the same shape holding 0.0 / 1.0.
    """
    is_positive = labels > 0
    return is_positive + 0.0
def twoargmax(a):
    """Return the positions of the two largest elements of ``a``.

    Elements are ranked by the pair ``(value, position)``, so equal values are
    broken in favour of the later position. The result is in ascending rank
    order, i.e. ``[index_of_second_largest, index_of_largest]``. If ``a`` has
    fewer than two elements, every position it has is returned.

    Args:
        a: sized iterable of mutually comparable values (e.g. a list or a 1-D
            tensor, whose elements iterate as 0-d tensors).

    Returns:
        list[int]: at most two indices into ``a``.
    """
    ranked = sorted(zip(a, range(len(a))))
    top_two = ranked[-2:]
    return [position for _value, position in top_two]
def loss_function(output, target):
    """Mean binary cross-entropy between ``output`` and ``target``.

    ``F.binary_cross_entropy`` requires ``output`` to already be probabilities
    in [0, 1] (presumably produced by ``finalizer``'s sigmoid — confirm with
    the training loop), and ``target`` to have the same shape.

    Returns:
        0-d tensor: the element-wise BCE averaged over all elements.
    """
    loss = F.binary_cross_entropy(output, target, reduction="mean")
    return loss
def count_correct(output, target):
    """Count rows whose two top-scoring positions are both labelled 1.

    For each row ``i`` of ``output``, ``twoargmax`` selects the indices of the
    two largest scores. The row counts as correct only if ``target[i]`` equals
    1 at every selected index. All selected indices are checked; there is no
    early exit on the first mismatch.

    Args:
        output: row-iterable scores (presumably a 2-D tensor of sigmoid outputs,
            one row per sample — confirm with caller).
        target: indexable as ``target[i][j]``; entries are compared with 1.

    Returns:
        int: number of rows meeting the criterion.
    """
    hits = 0
    for row_idx, scores in enumerate(output):
        mismatches = [target[row_idx][j] != 1 for j in twoargmax(scores)]
        if not any(mismatches):
            hits += 1
    return hits
def finalizer(x):
    """Squash raw scores element-wise through the logistic sigmoid.

    Computes ``1 / (1 + exp(-x))`` for every element, turning logits into
    per-label probabilities suitable for ``binary_cross_entropy``.

    Args:
        x: tensor of raw model outputs.

    Returns:
        Tensor of the same shape with values in [0, 1].
    """
    probabilities = torch.sigmoid(x)
    return probabilities
# Module-level constant. Presumably the number of output units / labels the
# model predicts (one sigmoid per label) — TODO confirm against the model
# definition that reads it.
outputs = 6
|