Diffstat (limited to 'model/attention.py')
-rw-r--r--  model/attention.py  7
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/model/attention.py b/model/attention.py
index ffc07d3..75ff5a0 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -31,8 +31,9 @@ class Head(nn.Module):
         return value, weights
 
 class Attention(nn.Module):
-    def __init__(self, hidden_dim, num_heads):
+    def __init__(self, hidden_dim, num_heads, device):
         super(Attention, self).__init__()
+        self._device = device
         self._num_heads = num_heads
         self._head_output_dim = hidden_dim // num_heads
         # ensure hidden_dim is divisible by num_heads
@@ -45,9 +46,9 @@ class Attention(nn.Module):
 
     def forward(self, x):
         # x shape: (seqlen, batch, hiddendim)
-        result = torch.zeros(x.shape)
+        result = torch.zeros(x.shape).to(self._device)
         # attentions are (heads, seqlen, batch, seqlen)
-        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0])
+        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0]).to(self._device)
         for i in range(self._num_heads):
             from_index = i * self._head_output_dim
             to_index = from_index + self._head_output_dim
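
For context, a minimal usage sketch of the updated constructor signature. The dimensions, the CUDA fallback, and the input values are illustrative assumptions, not part of this commit:

import torch

from model.attention import Attention

# Pick the GPU when available; the module now takes the device at construction time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# hidden_dim must be divisible by num_heads (64 / 4 heads here).
attention = Attention(hidden_dim=64, num_heads=4, device=device).to(device)

# Input shape follows the comment in forward(): (seqlen, batch, hiddendim).
x = torch.randn(10, 2, 64, device=device)
output = attention(x)

An alternative to threading the device through the constructor would be to allocate the buffers on the input's device at call time, e.g. torch.zeros(x.shape, device=x.device), which keeps the module device-agnostic; the explicit-device approach taken in this commit achieves the same result.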