blob: 63ec6eae05e57ba66619950b81b472ec3d8fe846 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
from torch import nn
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear.

    Expands each feature vector from ``hidden_dim`` to ``d_ff`` features,
    applies a ReLU non-linearity, then projects back to ``hidden_dim``.

    Args:
        hidden_dim: Size of the input and output feature dimension.
        d_ff: Size of the inner (expanded) feature dimension.
    """

    def __init__(self, hidden_dim, d_ff):
        super().__init__()
        # Expansion to the inner dimension and projection back to the model width.
        self.linear1 = nn.Linear(hidden_dim, d_ff)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, hidden_dim)

    def forward(self, x):
        """Apply the feed-forward network independently at every position.

        Args:
            x: Tensor whose last dimension is ``hidden_dim``; the original
                author documents it as (seqlen, batch, hiddendim). Because
                ``nn.Linear`` acts on the last dimension only, any number of
                leading dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.linear2(self.activation(self.linear1(x)))
|