import numpy as np
import torch
from torch import nn


class Head(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(Head, self).__init__()
        self._scaling_factor = np.sqrt(output_dim)
        self.query_projection = nn.Linear(hidden_dim, output_dim)
        self.key_projection = nn.Linear(hidden_dim, output_dim)
        self.value_projection = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """
        x shape: (seqlen, batch, hiddendim)
        """
        # get batch in front
        x = torch.transpose(x, 0, 1)
        query = self.query_projection(x)
        key = self.key_projection(x)
        value = self.value_projection(x)
        # transpose the matrix dimensions of key to align for multiplication
        product = query.matmul(torch.transpose(key, 1, 2))
        weights = torch.softmax(product / self._scaling_factor, dim=2)
        value = weights.matmul(value)
        # put seqlen back in front: value is (seqlen, batch, output_dim),
        # weights is (seqlen, batch, seqlen)
        value = torch.transpose(value, 0, 1)
        weights = torch.transpose(weights, 0, 1)
        return value, weights


class Attention(nn.Module):
    def __init__(self, hidden_dim, num_heads, device):
        super(Attention, self).__init__()
        self._device = device
        self._num_heads = num_heads
        self._head_output_dim = hidden_dim // num_heads
        # ensure hidden_dim is divisible by num_heads
        assert self._head_output_dim * num_heads == hidden_dim
        self.heads = nn.ModuleList([
            Head(hidden_dim, self._head_output_dim) for _ in range(num_heads)
        ])
        self.final_projection = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        # x shape: (seqlen, batch, hiddendim)
        result = torch.zeros(x.shape).to(self._device)
        # attentions are (heads, seqlen, batch, seqlen)
        attentions = torch.zeros(
            self._num_heads, x.shape[0], x.shape[1], x.shape[0]
        ).to(self._device)
        for i in range(self._num_heads):
            from_index = i * self._head_output_dim
            to_index = from_index + self._head_output_dim
            result[:, :, from_index:to_index], attentions[i, :, :, :] = self.heads[i](x)
        result = self.final_projection(result)
        return result, attentions
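

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module above: the device, sequence
    # length, batch size, hidden_dim, and num_heads chosen here are illustrative
    # assumptions, picked only to exercise one forward pass and check shapes.
    device = torch.device("cpu")
    seqlen, batch, hidden_dim, num_heads = 10, 4, 64, 8

    attention = Attention(hidden_dim, num_heads, device).to(device)
    x = torch.randn(seqlen, batch, hidden_dim, device=device)

    result, attentions = attention(x)
    # result keeps the input layout: (seqlen, batch, hidden_dim)
    assert result.shape == (seqlen, batch, hidden_dim)
    # attentions stacks one weight matrix per head: (num_heads, seqlen, batch, seqlen)
    assert attentions.shape == (num_heads, seqlen, batch, seqlen)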