blob: d82ca17ad5edd74dcc20527930d4c3c54e779124 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
from torch import nn
class Attention(nn.Module):
    """Multi-head self-attention over a sequence-first input.

    Thin wrapper around ``nn.MultiheadAttention`` that uses the same tensor
    as query, key and value, and returns the attention weights together
    with the attended output so the weights can be visualized.

    Args:
        hidden_dim: Size of the last (feature) dimension of the input.
        num_heads: Number of attention heads; must divide ``hidden_dim``.

    Raises:
        ValueError: If either argument is not positive, or if ``hidden_dim``
            is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads):
        super(Attention, self).__init__()
        # Validate up front so a bad configuration fails with a clear
        # message instead of a bare assertion deep inside PyTorch.
        if hidden_dim <= 0 or num_heads <= 0:
            raise ValueError(
                "hidden_dim and num_heads must be positive, got "
                "hidden_dim={} and num_heads={}".format(hidden_dim, num_heads)
            )
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "hidden_dim ({}) must be divisible by num_heads ({})".format(
                    hidden_dim, num_heads
                )
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # ``batch_first`` is left at its default (False), so the module
        # consumes the (seqlen, batch, hidden_dim) layout used by forward().
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (seqlen, batch, hidden_dim).

        Returns:
            A ``(result, att_weights)`` tuple. ``result`` has the same shape
            as ``x``. ``att_weights`` holds the attention weights as returned
            by ``nn.MultiheadAttention`` with its default settings, i.e.
            averaged over heads with shape (batch, seqlen, seqlen).

        NOTE(review): no attention or padding mask is applied because the
        signature only receives ``x``; if this layer is used autoregressively
        a causal mask would be required -- confirm against the caller.
        """
        # Self-attention: query, key and value are all the input sequence.
        result, att_weights = self.attention(x, x, x)
        return result, att_weights
|