m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/attention.py
blob: 75ff5a0473f563a7a9e5aba0933f977d5f6c25e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import torch
from torch import nn

class Head(nn.Module):
    """A single scaled dot-product self-attention head.

    The input is projected by three separate linear layers into queries,
    keys and values of size ``output_dim``; every position then attends
    over all positions of the same sequence.
    """

    def __init__(self, hidden_dim, output_dim):
        """Create the query/key/value projections.

        Args:
            hidden_dim: size of the last dimension of the input.
            output_dim: size of the last dimension of the head output.
        """
        super().__init__()
        # The layers are created in query, key, value order so that seeded
        # parameter initialisation stays reproducible.
        self.query_projection = nn.Linear(hidden_dim, output_dim)
        self.key_projection = nn.Linear(hidden_dim, output_dim)
        self.value_projection = nn.Linear(hidden_dim, output_dim)
        # Scores are divided by sqrt(output_dim) before the softmax.
        self._scaling_factor = np.sqrt(output_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (seqlen, batch, hidden_dim).

        Returns:
            A ``(value, weights)`` tuple where ``value`` has shape
            (seqlen, batch, output_dim) and ``weights`` has shape
            (seqlen, batch, seqlen); the last axis of ``weights`` sums to 1.
        """
        # Work batch-first so the matrix products run per batch element.
        batch_first = x.transpose(0, 1)

        queries = self.query_projection(batch_first)
        keys = self.key_projection(batch_first)
        values = self.value_projection(batch_first)

        # (batch, seqlen, dim) x (batch, dim, seqlen) -> (batch, seqlen, seqlen)
        scores = torch.matmul(queries, keys.transpose(1, 2)) / self._scaling_factor
        attention = scores.softmax(dim=2)
        attended = torch.matmul(attention, values)

        # Hand both results back in the caller's sequence-first layout.
        return attended.transpose(0, 1), attention.transpose(0, 1)

class Attention(nn.Module):
    """Multi-head self-attention.

    Runs ``num_heads`` independent heads over the same input, concatenates
    their outputs along the feature axis and applies a final linear
    projection back to ``hidden_dim``.
    """

    def __init__(self, hidden_dim, num_heads, device):
        """Build the heads and the output projection.

        Args:
            hidden_dim: size of the input/output feature dimension; must be
                divisible by ``num_heads``.
            num_heads: number of attention heads.
            device: stored as ``_device`` for backward compatibility only.
                ``forward`` derives device and dtype from its input, so the
                module keeps working after ``.to(...)``, ``.double()`` or
                ``.half()``.
        """
        super(Attention, self).__init__()
        self._device = device
        self._num_heads = num_heads
        self._head_output_dim = hidden_dim // num_heads
        # The concatenated head outputs must add up to exactly hidden_dim,
        # otherwise final_projection would receive the wrong feature size.
        assert self._head_output_dim * num_heads == hidden_dim, (
            "hidden_dim must be divisible by num_heads"
        )
        self.heads = nn.ModuleList([
            Head(hidden_dim, self._head_output_dim) for _ in range(num_heads)
        ])
        self.final_projection = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply every head to ``x`` and merge the results.

        Args:
            x: tensor of shape (seqlen, batch, hidden_dim).

        Returns:
            A ``(result, attentions)`` tuple where ``result`` has shape
            (seqlen, batch, hidden_dim) and ``attentions`` has shape
            (heads, seqlen, batch, seqlen).
        """
        head_values = []
        head_weights = []
        for head in self.heads:
            value, weights = head(x)
            head_values.append(value)
            head_weights.append(weights)

        # Concatenating/stacking (instead of writing into pre-allocated
        # float32 buffers on a fixed device) keeps the dtype and device of
        # the head outputs and avoids a CPU allocation + copy per call.
        merged = torch.cat(head_values, dim=2)
        attentions = torch.stack(head_weights, dim=0)

        result = self.final_projection(merged)
        return result, attentions