From 0b67d2c000f31c9048c2ea346448184f1ce97e0d Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sun, 23 May 2021 13:31:40 +0200 Subject: Implement feed forward layer --- model/forward.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'model') diff --git a/model/forward.py b/model/forward.py index 63ec6ea..7128be1 100644 --- a/model/forward.py +++ b/model/forward.py @@ -1,14 +1,15 @@ from torch import nn +import torch.nn.functional as F class FeedForward(nn.Module): def __init__(self, hidden_dim, d_ff): super(FeedForward, self).__init__() - # TODO: implement FeedForward layer - pass + self.fc1 = nn.Linear(hidden_dim, d_ff) + self.fc2 = nn.Linear(d_ff, hidden_dim) def forward(self, x): - # TODO: implement # x shape: (seqlen, batch, hiddendim) - result = x # placeholder - pass - return result + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x -- cgit v1.2.3