diff options
Diffstat (limited to 'model')
-rw-r--r-- | model/encoder_layer.py | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/model/encoder_layer.py b/model/encoder_layer.py index 56a3a0c..71a7d8f 100644 --- a/model/encoder_layer.py +++ b/model/encoder_layer.py @@ -1,15 +1,26 @@ from torch import nn +from model.attention import Attention +from model.forward import FeedForward + class EncoderLayer(nn.Module): - def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, - use_feedforward=True): - super(EncoderLayer, self).__init__() - # TODO: implement a single encoder layer, using Attention and FeedForward. - pass - - def forward(self, x): - # x shape: (seqlen, batch, hiddendim) - # TODO: implement a single encoder layer, using Attention and FeedForward. - result, att_weights = x, None # placeholder - pass - return result, att_weights + def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, + use_feedforward=True): + super(EncoderLayer, self).__init__() + self._use_attention = use_attention + self._use_feedforward = use_feedforward + if use_attention: + self.attention = Attention(hidden_dim, num_heads) + if use_feedforward: + self.feedforward = FeedForward(hidden_dim, d_ff) + + def forward(self, x): + weights = None + if self._use_attention: + y, weights = self.attention(x) + x = x + y + if self._use_feedforward: + y = self.feedforward(x) + x = x + y + + return x, weights |