From b60ec2754313b013a29aa068bea7a55ebe00453c Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Tue, 25 May 2021 17:37:20 +0200 Subject: Implement encoder layer --- model/encoder_layer.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) (limited to 'model') diff --git a/model/encoder_layer.py b/model/encoder_layer.py index 56a3a0c..71a7d8f 100644 --- a/model/encoder_layer.py +++ b/model/encoder_layer.py @@ -1,15 +1,26 @@ from torch import nn +from model.attention import Attention +from model.forward import FeedForward + class EncoderLayer(nn.Module): - def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, - use_feedforward=True): - super(EncoderLayer, self).__init__() - # TODO: implement a single encoder layer, using Attention and FeedForward. - pass - - def forward(self, x): - # x shape: (seqlen, batch, hiddendim) - # TODO: implement a single encoder layer, using Attention and FeedForward. - result, att_weights = x, None # placeholder - pass - return result, att_weights + def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, + use_feedforward=True): + super(EncoderLayer, self).__init__() + self._use_attention = use_attention + self._use_feedforward = use_feedforward + if use_attention: + self.attention = Attention(hidden_dim, num_heads) + if use_feedforward: + self.feedforward = FeedForward(hidden_dim, d_ff) + + def forward(self, x): + weights = None + if self._use_attention: + y, weights = self.attention(x) + x = x + y + if self._use_feedforward: + y = self.feedforward(x) + x = x + y + + return x, weights -- cgit v1.2.3