author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-25 17:37:20 +0200
---|---|---
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-25 17:37:20 +0200
commit | b60ec2754313b013a29aa068bea7a55ebe00453c (patch) |
tree | 69041df4e86fcbb866aa236546eda8ce7189f53a |
parent | 1c363490751112b2e81276b3ce838c6f3f86a656 (diff) |
Implement encoder layer
-rw-r--r-- | model/encoder_layer.py | 35 |
1 file changed, 23 insertions, 12 deletions
```diff
diff --git a/model/encoder_layer.py b/model/encoder_layer.py
index 56a3a0c..71a7d8f 100644
--- a/model/encoder_layer.py
+++ b/model/encoder_layer.py
@@ -1,15 +1,26 @@
 from torch import nn
 
+from model.attention import Attention
+from model.forward import FeedForward
+
 class EncoderLayer(nn.Module):
-    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
-            use_feedforward=True):
-        super(EncoderLayer, self).__init__()
-        # TODO: implement a single encoder layer, using Attention and FeedForward.
-        pass
-
-    def forward(self, x):
-        # x shape: (seqlen, batch, hiddendim)
-        # TODO: implement a single encoder layer, using Attention and FeedForward.
-        result, att_weights = x, None  # placeholder
-        pass
-        return result, att_weights
+    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
+            use_feedforward=True):
+        super(EncoderLayer, self).__init__()
+        self._use_attention = use_attention
+        self._use_feedforward = use_feedforward
+        if use_attention:
+            self.attention = Attention(hidden_dim, num_heads)
+        if use_feedforward:
+            self.feedforward = FeedForward(hidden_dim, d_ff)
+
+    def forward(self, x):
+        weights = None
+        if self._use_attention:
+            y, weights = self.attention(x)
+            x = x + y
+        if self._use_feedforward:
+            y = self.feedforward(x)
+            x = x + y
+
+        return x, weights
```
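For context, a quick smoke test of the new layer. This is a minimal sketch, not part of the commit: it assumes the repository's `Attention` and `FeedForward` modules accept the constructor arguments used above and that `Attention` returns an `(output, weights)` pair, as the new `forward` method expects; the tensor dimensions below are arbitrary example values.

```python
import torch

from model.encoder_layer import EncoderLayer

# Arbitrary example dimensions (not taken from the repository's config).
seq_len, batch, hidden_dim, d_ff, num_heads = 10, 2, 64, 256, 4

layer = EncoderLayer(hidden_dim, d_ff, num_heads)
x = torch.randn(seq_len, batch, hidden_dim)  # (seqlen, batch, hiddendim), per the removed comment

out, att_weights = layer(x)
# The residual additions (x = x + y) leave the hidden size unchanged.
print(out.shape)  # expected: torch.Size([10, 2, 64])
```

Because both sub-blocks are wrapped in residual additions, constructing the layer with `use_attention=False` or `use_feedforward=False` simply skips that sub-block while keeping the input/output shape, which makes the flags convenient for ablations.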