diff options
Diffstat (limited to 'model')
-rw-r--r-- | model/forward.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/model/forward.py b/model/forward.py index 63ec6ea..7128be1 100644 --- a/model/forward.py +++ b/model/forward.py @@ -1,14 +1,15 @@ from torch import nn +import torch.nn.functional as F class FeedForward(nn.Module): def __init__(self, hidden_dim, d_ff): super(FeedForward, self).__init__() - # TODO: implement FeedForward layer - pass + self.fc1 = nn.Linear(hidden_dim, d_ff) + self.fc2 = nn.Linear(d_ff, hidden_dim) def forward(self, x): - # TODO: implement # x shape: (seqlen, batch, hiddendim) - result = x # placeholder - pass - return result + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x |