m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/attention.py
blob: d82ca17ad5edd74dcc20527930d4c3c54e779124 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch import nn

class Attention(nn.Module):
  """Multi-head self-attention over a sequence.

  Wraps nn.MultiheadAttention as self-attention (query = key = value = x),
  returning both the attended output and the attention weights so the
  weights can be visualized.

  Args:
    hidden_dim: embedding dimension of the input (must be divisible by
      num_heads, as required by nn.MultiheadAttention).
    num_heads: number of parallel attention heads.
  """

  def __init__(self, hidden_dim, num_heads):
    super(Attention, self).__init__()
    # batch_first defaults to False, which matches the documented
    # (seqlen, batch, hiddendim) input layout of forward().
    self.attention = nn.MultiheadAttention(hidden_dim, num_heads)

  def forward(self, x):
    """Apply self-attention to x.

    Args:
      x: tensor of shape (seqlen, batch, hidden_dim).

    Returns:
      result: attended tensor of shape (seqlen, batch, hidden_dim).
      att_weights: attention weights (averaged over heads by default),
        shape (batch, seqlen, seqlen) — suitable for visualization.
    """
    # Self-attention: the sequence attends to itself.
    result, att_weights = self.attention(x, x, x)
    return result, att_weights