From 1c363490751112b2e81276b3ce838c6f3f86a656 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Tue, 25 May 2021 17:36:38 +0200 Subject: Modify attention implementation --- model/attention.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 12 deletions(-) (limited to 'model') diff --git a/model/attention.py b/model/attention.py index d82ca17..48c9e29 100644 --- a/model/attention.py +++ b/model/attention.py @@ -1,15 +1,71 @@ +import numpy as np +import torch from torch import nn +class Projection(nn.Module): + """ + Linear transformation by a matrix. + """ + def __init__(self, in_dim, out_dim): + super(Projection, self).__init__() + self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, )) + + def forward(self, x): + """ + x shape: batch, seqlen, in_dim + """ + return x.matmul(self.projection) + +class Head(nn.Module): + def __init__(self, hidden_dim, output_dim): + super(Head, self).__init__() + self._scaling_factor = np.sqrt(output_dim) + self.query_projection = Projection(hidden_dim, output_dim) + self.key_projection = Projection(hidden_dim, output_dim) + self.value_projection = Projection(hidden_dim, output_dim) + + def forward(self, x): + """ + x shape: seqlen, batch, hiddendim + """ + # get batch in front + x = torch.transpose(x, 0, 1) + + query = self.query_projection(x) + key = self.key_projection(x) + value = self.value_projection(x) + + # transpose the matrix dimensions of key to align for multiplication + product = query.matmul(torch.transpose(key, 1, 2)) + weights = torch.softmax(product / self._scaling_factor, dim=2) + value = weights.matmul(value) + value = torch.transpose(value, 0, 1) + weights = torch.transpose(weights, 0, 1) + + return value, weights + class Attention(nn.Module): - def __init__(self, hidden_dim, num_heads): - super(Attention, self).__init__() - # TODO: implement Attention - pass - - def forward(self, x): - # TODO: implement Attention; return both result of attention mechanism and - # attention weights (for visualization). - # x shape: (seqlen, batch, hiddendim) - result, att_weights = x, None # placeholder - pass - return result, att_weights + def __init__(self, hidden_dim, num_heads): + super(Attention, self).__init__() + self._num_heads = num_heads + self._head_output_dim = hidden_dim // num_heads + # ensure hidden_dim is divisible by num_heads + assert(self._head_output_dim * num_heads == hidden_dim) + self.heads = nn.ModuleList([ + Head(hidden_dim, self._head_output_dim) for _ in range(num_heads) + ]) + self.final_projection = nn.Linear(hidden_dim, hidden_dim) + pass + + def forward(self, x): + # x shape: (seqlen, batch, hiddendim) + result = torch.zeros(x.shape) + # attentions are (heads, seqlen, batch, seqlen) + attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0]) + for i in range(self._num_heads): + from_index = i * self._head_output_dim + to_index = from_index + self._head_output_dim + result[:, :, from_index:to_index], attentions[i, :, :, :] = self.heads[i](x) + + result = self.final_projection(result) + return result, attentions -- cgit v1.2.3