from torch import nn

from model.attention import Attention
from model.forward import FeedForward


class EncoderLayer(nn.Module):
    """A single Transformer encoder layer: self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual
    connection. Either sub-layer can be disabled independently
    (e.g. for ablations) via the constructor flags."""

    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, use_feedforward=True):
        super().__init__()
        self._use_attention = use_attention
        self._use_feedforward = use_feedforward
        if use_attention:
            self.attention = Attention(hidden_dim, num_heads)
        if use_feedforward:
            self.feedforward = FeedForward(hidden_dim, d_ff)

    def forward(self, x):
        # Attention weights are only produced when the attention sub-layer is enabled.
        weights = None
        if self._use_attention:
            y, weights = self.attention(x)
            x = x + y  # residual connection around self-attention
        if self._use_feedforward:
            y = self.feedforward(x)
            x = x + y  # residual connection around the feed-forward block
        return x, weights
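
# A minimal usage sketch, assuming Attention(hidden_dim, num_heads) returns an
# (output, attention_weights) pair and that both sub-modules preserve the input
# shape (batch, seq_len, hidden_dim); the dimensions below are illustrative:
#
#     import torch
#     layer = EncoderLayer(hidden_dim=512, d_ff=2048, num_heads=8)
#     x = torch.randn(2, 16, 512)  # (batch, seq_len, hidden_dim)
#     out, attn = layer(x)         # out has the same shape as x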