from torch import nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Two-layer position-wise feed-forward network.

    Applies ``Linear(hidden_dim -> d_ff)``, a ReLU non-linearity, and
    ``Linear(d_ff -> hidden_dim)`` in sequence, so the output has the same
    trailing dimension as the input.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        d_ff: Size of the intermediate (expanded) representation.
    """

    def __init__(self, hidden_dim: int, d_ff: int):
        super().__init__()
        # Expansion layer followed by a projection back to the input width.
        self.fc1 = nn.Linear(hidden_dim, d_ff)
        self.fc2 = nn.Linear(d_ff, hidden_dim)

    def forward(self, x):
        """Run the feed-forward transformation on ``x``.

        Args:
            x: Input tensor whose last dimension is ``hidden_dim``. The
                original author's note gives the expected layout as
                ``(seqlen, batch, hiddendim)``; ``nn.Linear`` only acts on
                the last dimension, so leading dimensions pass through
                unchanged.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)