diff options
Diffstat (limited to 'model/encoder_layer.py')
-rw-r--r-- | model/encoder_layer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/model/encoder_layer.py b/model/encoder_layer.py index 71a7d8f..311c39c 100644 --- a/model/encoder_layer.py +++ b/model/encoder_layer.py @@ -5,12 +5,12 @@ from model.forward import FeedForward class EncoderLayer(nn.Module): def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True, - use_feedforward=True): + use_feedforward=True, device='cpu'): super(EncoderLayer, self).__init__() self._use_attention = use_attention self._use_feedforward = use_feedforward if use_attention: - self.attention = Attention(hidden_dim, num_heads) + self.attention = Attention(hidden_dim, num_heads, device) if use_feedforward: self.feedforward = FeedForward(hidden_dim, d_ff) |