m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/encoder.py')
-rw-r--r--model/encoder.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/model/encoder.py b/model/encoder.py
index 85b3141..d6527dd 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -17,14 +17,21 @@ class EncoderModel(nn.Module):
self._use_positional = use_positional
self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
self.layers = nn.ModuleList([
- EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
- use_feedforward) for i in range(n_layers)
+ EncoderLayer(
+ hidden_dim,
+ d_ff,
+ num_heads,
+ use_attention,
+ use_feedforward,
+ device=device
+ ) for i in range(n_layers)
])
self.output_layer = nn.Linear(hidden_dim, output_dim)
def forward(self, x, return_att_weights=False, verbose=False):
log(f'Handling {x}', verbose)
# x shape: (seqlen, batch)
+ x = x.to(self._device)
hidden = self.embedding_layer(x)
# hidden shape: (seqlen, batch, hiddendim)