import torch
from torch import nn

from util.util import get_positional_encoding
from model.encoder_layer import EncoderLayer


class EncoderModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
                 num_heads, use_attention=True, use_feedforward=True,
                 use_positional=True, device='cpu'):
        super(EncoderModel, self).__init__()
        self._device = device
        self._use_positional = use_positional

        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
        self.layers = nn.ModuleList([
            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention, use_feedforward)
            for _ in range(n_layers)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, return_att_weights=False):
        # x shape: (seqlen, batch)
        hidden = self.embedding_layer(x)
        # hidden shape: (seqlen, batch, hiddendim)

        if self._use_positional:
            positional_encoding = get_positional_encoding(
                n_positions=hidden.shape[0],
                n_dimensions=hidden.shape[-1],
                device=self._device
            )
            # reshaping to (seqlen, 1, hiddendim)
            positional_encoding = torch.reshape(
                positional_encoding,
                (hidden.shape[0], 1, hidden.shape[-1])
            )
            hidden = hidden + positional_encoding

        list_att_weights = []
        for layer in self.layers:
            hidden, att_weights = layer(hidden)
            list_att_weights.append(att_weights)

        result = self.output_layer(hidden)

        if return_att_weights:
            return result, list_att_weights
        else:
            return result
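
# Minimal usage sketch. Assumptions: EncoderLayer follows the interface used in
# forward() above (returns a (hidden, att_weights) pair), and hidden_dim is
# divisible by num_heads, as is typical for multi-head attention. The vocabulary
# size, dimensions, and sequence/batch sizes below are illustrative only.
if __name__ == '__main__':
    model = EncoderModel(
        input_dim=1000,    # vocabulary size (hypothetical)
        hidden_dim=64,     # embedding / model dimension
        d_ff=256,          # feed-forward inner dimension
        output_dim=10,     # e.g. number of classes per token
        n_layers=2,
        num_heads=4,
    )

    # Token indices shaped (seqlen, batch), matching the convention in forward().
    x = torch.randint(0, 1000, (20, 8))

    logits, att_weights = model(x, return_att_weights=True)
    print(logits.shape)       # expected: (20, 8, 10)
    print(len(att_weights))   # one attention-weight tensor per encoder layer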