1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
import numpy as np
import torch
from torch import nn
class Projection(nn.Module):
    """
    Learnable bias-free linear map ``x -> x @ W``.

    ``W`` has shape (in_dim, out_dim); every entry is sampled from a normal
    distribution with mean 0 and standard deviation 1 at construction time.
    """

    def __init__(self, in_dim, out_dim):
        super(Projection, self).__init__()
        mean = torch.zeros(in_dim, out_dim)
        self.projection = nn.Parameter(torch.normal(mean, 1))

    def forward(self, x):
        """
        Multiply ``x`` on the right by the projection matrix.

        x shape: batch, seqlen, in_dim (the last dimension must equal in_dim).
        Returns a tensor whose last dimension is out_dim.
        """
        return torch.matmul(x, self.projection)
class Head(nn.Module):
    """
    A single scaled dot-product self-attention head.

    Query, key and value each come from their own ``Projection`` mapping
    ``hidden_dim`` features to ``output_dim`` features. Attention scores are
    divided by ``sqrt(output_dim)`` before the softmax.
    """

    def __init__(self, hidden_dim, output_dim):
        super(Head, self).__init__()
        self.query_projection = Projection(hidden_dim, output_dim)
        self.key_projection = Projection(hidden_dim, output_dim)
        self.value_projection = Projection(hidden_dim, output_dim)
        self._scaling_factor = np.sqrt(output_dim)

    def forward(self, x):
        """
        x shape: seqlen, batch, hiddendim

        Returns ``(value, weights)``. ``value`` is (seqlen, batch, output_dim)
        and ``weights`` is (seqlen, batch, seqlen), both in seqlen-first layout.
        """
        # Work in batch-first layout so matmul treats batch as the batch axis.
        batch_first = x.transpose(0, 1)
        q = self.query_projection(batch_first)
        k = self.key_projection(batch_first)
        v = self.value_projection(batch_first)

        # scores[b, i, j] = <q[b, i], k[b, j]>
        scores = q @ k.transpose(1, 2)
        weights = torch.softmax(scores / self._scaling_factor, dim=2)
        attended = weights @ v

        # Restore the seqlen-first layout the caller passed in.
        return attended.transpose(0, 1), weights.transpose(0, 1)
class Attention(nn.Module):
    """
    Multi-head self-attention assembled from independent ``Head`` modules.

    Each head maps the full ``hidden_dim`` input to
    ``hidden_dim // num_heads`` features. The head outputs are concatenated
    along the feature axis and passed through a final ``nn.Linear``.
    """

    def __init__(self, hidden_dim, num_heads):
        """
        Args:
            hidden_dim: size of the input and output feature dimension.
            num_heads: number of attention heads; must be positive and must
                divide ``hidden_dim`` evenly.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``hidden_dim``.
        """
        super(Attention, self).__init__()
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self._num_heads = num_heads
        self._head_output_dim = hidden_dim // num_heads
        self.heads = nn.ModuleList([
            Head(hidden_dim, self._head_output_dim) for _ in range(num_heads)
        ])
        self.final_projection = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (seqlen, batch, hiddendim).

        Returns:
            ``(result, attentions)``. ``result`` has shape
            (seqlen, batch, hiddendim). ``attentions`` has shape
            (heads, seqlen, batch, seqlen), with head ``i`` at index ``i``.
        """
        head_outputs = [head(x) for head in self.heads]
        # Build the outputs from the heads' own tensors rather than from
        # freshly allocated torch.zeros(...). This keeps them on x's
        # device/dtype; the previous CPU-only buffers broke on GPU inputs.
        # Head i's features land in columns [i*d, (i+1)*d), as before.
        result = torch.cat([value for value, _ in head_outputs], dim=2)
        attentions = torch.stack([weights for _, weights in head_outputs], dim=0)
        result = self.final_projection(result)
        return result, attentions
|