diff options
Diffstat (limited to 'model/encoder.py')
-rw-r--r-- | model/encoder.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/model/encoder.py b/model/encoder.py index 85b3141..d6527dd 100644 --- a/model/encoder.py +++ b/model/encoder.py @@ -17,14 +17,21 @@ class EncoderModel(nn.Module): self._use_positional = use_positional self.embedding_layer = nn.Embedding(input_dim, hidden_dim) self.layers = nn.ModuleList([ - EncoderLayer(hidden_dim, d_ff, num_heads, use_attention, - use_feedforward) for i in range(n_layers) + EncoderLayer( + hidden_dim, + d_ff, + num_heads, + use_attention, + use_feedforward, + device=device + ) for i in range(n_layers) ]) self.output_layer = nn.Linear(hidden_dim, output_dim) def forward(self, x, return_att_weights=False, verbose=False): log(f'Handling {x}', verbose) # x shape: (seqlen, batch) + x = x.to(self._device) hidden = self.embedding_layer(x) # hidden shape: (seqlen, batch, hiddendim) |