m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/forward.py
blob: 63ec6eae05e57ba66619950b81b472ec3d8fe846 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch import nn

class FeedForward(nn.Module):
  """Position-wise feed-forward network.

  Applies the same two-layer perceptron to every position independently:
  ``Linear(hidden_dim -> d_ff)``, ReLU, ``Linear(d_ff -> hidden_dim)``.
  The output therefore has the same shape as the input.

  Args:
    hidden_dim: size of the last dimension of the input and of the output.
    d_ff: size of the inner (expanded) layer.
  """

  def __init__(self, hidden_dim, d_ff):
    super(FeedForward, self).__init__()
    # nn.Linear only transforms the last dimension, so the leading
    # (seqlen, batch) dimensions pass through untouched and no reshaping
    # is needed. ReLU follows the original Transformer formulation.
    self.net = nn.Sequential(
      nn.Linear(hidden_dim, d_ff),
      nn.ReLU(),
      nn.Linear(d_ff, hidden_dim),
    )

  def forward(self, x):
    """Run the feed-forward network on ``x``.

    Args:
      x: tensor whose last dimension is ``hidden_dim``; expected shape is
        (seqlen, batch, hiddendim).

    Returns:
      Tensor of the same shape as ``x``.
    """
    # x shape: (seqlen, batch, hiddendim)
    return self.net(x)