m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model/forward.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/model/forward.py b/model/forward.py
index 63ec6ea..7128be1 100644
--- a/model/forward.py
+++ b/model/forward.py
@@ -1,14 +1,15 @@
from torch import nn
+import torch.nn.functional as F
class FeedForward(nn.Module):
def __init__(self, hidden_dim, d_ff):
super(FeedForward, self).__init__()
- # TODO: implement FeedForward layer
- pass
+ self.fc1 = nn.Linear(hidden_dim, d_ff)
+ self.fc2 = nn.Linear(d_ff, hidden_dim)
def forward(self, x):
- # TODO: implement
# x shape: (seqlen, batch, hiddendim)
- result = x # placeholder
- pass
- return result
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.fc2(x)
+ return x