Diffstat (limited to 'model/attention.py')
-rw-r--r-- | model/attention.py | 7 | ++++---
1 file changed, 4 insertions, 3 deletions
diff --git a/model/attention.py b/model/attention.py
index ffc07d3..75ff5a0 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -31,8 +31,9 @@ class Head(nn.Module):
         return value, weights
 
 class Attention(nn.Module):
-    def __init__(self, hidden_dim, num_heads):
+    def __init__(self, hidden_dim, num_heads, device):
         super(Attention, self).__init__()
+        self._device = device
         self._num_heads = num_heads
         self._head_output_dim = hidden_dim // num_heads
         # ensure hidden_dim is divisible by num_heads
@@ -45,9 +46,9 @@
 
     def forward(self, x):
         # x shape: (seqlen, batch, hiddendim)
-        result = torch.zeros(x.shape)
+        result = torch.zeros(x.shape).to(self._device)
         # attentions are (heads, seqlen, batch, seqlen)
-        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0])
+        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0]).to(self._device)
         for i in range(self._num_heads):
             from_index = i * self._head_output_dim
             to_index = from_index + self._head_output_dim
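For context, a minimal usage sketch of the changed constructor signature (not part of the commit): it assumes the class is importable as model.attention.Attention; the hidden_dim, num_heads, and input-shape values are illustrative, and the return value of forward() falls outside the hunks shown above.

    import torch

    from model.attention import Attention

    # Pick the device once and pass it through; after this change the
    # module's internal zero tensors are created on that device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # .to(device) still moves the per-head parameters; the diff only
    # covers the tensors allocated inside forward().
    attn = Attention(hidden_dim=64, num_heads=4, device=device).to(device)

    x = torch.randn(10, 2, 64, device=device)  # (seqlen, batch, hiddendim)
    out = attn(x)

As a design note, torch.zeros(x.shape, device=self._device) would allocate directly on the target device instead of allocating on the CPU and then copying with .to(), and reading x.device inside forward() would avoid threading a device argument through the constructor at all.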