m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-23 13:31:40 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-23 13:31:40 +0200
commit0b67d2c000f31c9048c2ea346448184f1ce97e0d (patch)
treead0d6d6a5db04dbb0d430d8d44e594ede373e375
parent49375838a74bc8519f7ac9ec9a983316888bdacc (diff)
Implement feed forward layer
-rw-r--r--model/forward.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/model/forward.py b/model/forward.py
index 63ec6ea..7128be1 100644
--- a/model/forward.py
+++ b/model/forward.py
@@ -1,14 +1,15 @@
from torch import nn
+import torch.nn.functional as F
class FeedForward(nn.Module):
def __init__(self, hidden_dim, d_ff):
super(FeedForward, self).__init__()
- # TODO: implement FeedForward layer
- pass
+ self.fc1 = nn.Linear(hidden_dim, d_ff)
+ self.fc2 = nn.Linear(d_ff, hidden_dim)
def forward(self, x):
- # TODO: implement
# x shape: (seqlen, batch, hiddendim)
- result = x # placeholder
- pass
- return result
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.fc2(x)
+ return x