import torch
import torch.nn as nn
import torch.nn.functional as F


def target_transform(labels):
    """Binarize *labels*: 1.0 where a label is strictly positive, else 0.0.

    Adding ``0.`` turns the boolean mask produced by ``labels > 0`` into a
    floating-point result.
    """
    return (labels > 0) + 0.


def twoargmax(a):
    """Return the indices of the two largest entries of *a*.

    The indices are ordered by ascending value, so the index of the maximum
    comes last. Among equal values the higher index ranks higher. If *a* has
    fewer than two entries, all available indices are returned.
    """
    # sorted() is stable, so equal values keep ascending index order; this
    # matches sorting (value, index) pairs.
    ranked = sorted(range(len(a)), key=lambda i: a[i])
    return ranked[-2:]


def loss_function(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the binary cross-entropy between *output* and *target*."""
    return F.binary_cross_entropy(output, target)


def count_correct(output, target):
    """Count the rows whose two highest-scoring indices are both labelled 1.

    For every row ``output[i]`` the two top-scoring positions are looked up
    in ``target[i]``; the row counts as correct only if each of those target
    entries equals 1.
    """
    correct = 0
    for i, scores in enumerate(output):
        target_row = target[i]
        if all(target_row[idx] == 1 for idx in twoargmax(scores)):
            correct += 1
    return correct


def finalizer(x: torch.Tensor) -> torch.Tensor:
    """Apply the element-wise sigmoid to *x*."""
    return torch.sigmoid(x)


# NOTE(review): presumably the number of model outputs, consumed by code
# outside this file -- confirm against the caller.
outputs = 6