diff options
Diffstat (limited to 'model/encoder.py')
-rw-r--r-- | model/encoder.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/model/encoder.py b/model/encoder.py new file mode 100644 index 0000000..63a5149 --- /dev/null +++ b/model/encoder.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from util.util import get_positional_encoding +from model.encoder_layer import EncoderLayer + +class EncoderModel(nn.Module): + def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers, + num_heads, use_attention=True, use_feedforward=True, + use_positional=True, device='cpu'): + super(EncoderModel, self).__init__() + self._device = device + self._use_positional = use_positional + self.embedding_layer = nn.Embedding(input_dim, hidden_dim) + self.layers = nn.ModuleList([ + EncoderLayer(hidden_dim, d_ff, num_heads, use_attention, + use_feedforward) for i in range(n_layers)]) + self.output_layer = nn.Linear(hidden_dim, output_dim) + + def forward(self, x, return_att_weights=False): + # x shape: (seqlen, batch) + hidden = self.embedding_layer(x) + # hidden shape: (seqlen, batch, hiddendim) + + if self._use_positional: + positional_encoding = get_positional_encoding( + n_positions=hidden.shape[0], + n_dimensions=hidden.shape[-1], + device=self._device + ) + # reshaping to (seqlen, 1, hiddendim) + positional_encoding = torch.reshape( + positional_encoding, + (hidden.shape[0], 1, hidden.shape[-1]) + ) + hidden = hidden + positional_encoding + + list_att_weights = [] + for layer in self.layers: + hidden, att_weights = layer(hidden) + list_att_weights.append(att_weights) + + result = self.output_layer(hidden) + + if return_att_weights: + return result, list_att_weights + else: + return result |