# model/encoder_layer.py
from torch import nn

class EncoderLayer(nn.Module):
  def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
               use_feedforward=True):
    super().__init__()
    # TODO: implement a single encoder layer, using Attention and FeedForward.

  def forward(self, x):
    # x shape: (seq_len, batch, hidden_dim)
    # TODO: implement a single encoder layer, using Attention and FeedForward.
    result, att_weights = x, None  # placeholder until the layer is implemented
    return result, att_weights
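

# --- Reference sketch (not part of the skeleton above) ---
# A minimal example of how the TODOs could be filled in, assuming a standard
# post-norm Transformer encoder layer: self-attention and a position-wise
# feed-forward network, each wrapped in a residual connection and LayerNorm.
# The repository's own Attention and FeedForward modules are defined elsewhere
# and their interfaces are not shown here, so this sketch substitutes
# torch.nn.MultiheadAttention and a small nn.Sequential MLP. The class name
# EncoderLayerSketch and the smoke test below are illustrative assumptions,
# not the repository's solution.

import torch


class EncoderLayerSketch(nn.Module):
  def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
               use_feedforward=True):
    super().__init__()
    self.use_attention = use_attention
    self.use_feedforward = use_feedforward
    if use_attention:
      # nn.MultiheadAttention expects (seq_len, batch, hidden_dim) by default.
      self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
      self.attention_norm = nn.LayerNorm(hidden_dim)
    if use_feedforward:
      # Position-wise feed-forward network: hidden_dim -> d_ff -> hidden_dim.
      self.feedforward = nn.Sequential(
          nn.Linear(hidden_dim, d_ff),
          nn.ReLU(),
          nn.Linear(d_ff, hidden_dim),
      )
      self.feedforward_norm = nn.LayerNorm(hidden_dim)

  def forward(self, x):
    # x shape: (seq_len, batch, hidden_dim)
    att_weights = None
    if self.use_attention:
      # Self-attention sub-layer with residual connection and LayerNorm.
      att_out, att_weights = self.attention(x, x, x)
      x = self.attention_norm(x + att_out)
    if self.use_feedforward:
      # Feed-forward sub-layer with residual connection and LayerNorm.
      x = self.feedforward_norm(x + self.feedforward(x))
    return x, att_weights


if __name__ == '__main__':
  # Quick smoke test on random data.
  layer = EncoderLayerSketch(hidden_dim=16, d_ff=64, num_heads=4)
  out, weights = layer(torch.randn(10, 2, 16))
  print(out.shape, weights.shape)  # torch.Size([10, 2, 16]) torch.Size([2, 10, 10])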