m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/encoder.py
diff options
context:
space:
mode:
authorMarcin Chrzanowski <mc370754@students.mimuw.edu.pl>2021-05-27 21:05:36 +0200
committerMarcin Chrzanowski <mc370754@students.mimuw.edu.pl>2021-05-27 21:05:36 +0200
commit0fead7ba8062c5704b4a27c9a1c57427b6e8ecea (patch)
tree495814c1070aba7fdead59b8473b89c92aa92feb /model/encoder.py
parent0226b13c96e048282cc1d1868eaeb59fd89877b3 (diff)
Allow GPU use
Diffstat (limited to 'model/encoder.py')
-rw-r--r--model/encoder.py11
1 file changed, 9 insertions, 2 deletions
diff --git a/model/encoder.py b/model/encoder.py
index 85b3141..d6527dd 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -17,14 +17,21 @@ class EncoderModel(nn.Module):
self._use_positional = use_positional
self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
self.layers = nn.ModuleList([
- EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
- use_feedforward) for i in range(n_layers)
+ EncoderLayer(
+ hidden_dim,
+ d_ff,
+ num_heads,
+ use_attention,
+ use_feedforward,
+ device=device
+ ) for i in range(n_layers)
])
self.output_layer = nn.Linear(hidden_dim, output_dim)
def forward(self, x, return_att_weights=False, verbose=False):
log(f'Handling {x}', verbose)
# x shape: (seqlen, batch)
+ x = x.to(self._device)
hidden = self.embedding_layer(x)
# hidden shape: (seqlen, batch, hiddendim)