diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-23 13:31:40 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-23 13:31:40 +0200 |
commit | 0b67d2c000f31c9048c2ea346448184f1ce97e0d (patch) | |
tree | ad0d6d6a5db04dbb0d430d8d44e594ede373e375 /model | |
parent | 49375838a74bc8519f7ac9ec9a983316888bdacc (diff) |
Implement feed forward layer
Diffstat (limited to 'model')
-rw-r--r-- | model/forward.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/model/forward.py b/model/forward.py index 63ec6ea..7128be1 100644 --- a/model/forward.py +++ b/model/forward.py @@ -1,14 +1,15 @@ from torch import nn +import torch.nn.functional as F class FeedForward(nn.Module): def __init__(self, hidden_dim, d_ff): super(FeedForward, self).__init__() - # TODO: implement FeedForward layer - pass + self.fc1 = nn.Linear(hidden_dim, d_ff) + self.fc2 = nn.Linear(d_ff, hidden_dim) def forward(self, x): - # TODO: implement # x shape: (seqlen, batch, hiddendim) - result = x # placeholder - pass - return result + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x |