Diffstat (limited to 'model')
-rw-r--r--  model/attention.py  80
1 file changed, 68 insertions, 12 deletions
diff --git a/model/attention.py b/model/attention.py
index d82ca17..48c9e29 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -1,15 +1,71 @@
+import numpy as np
+import torch
from torch import nn
+class Projection(nn.Module):
+    """
+    Linear transformation by a matrix.
+    """
+    def __init__(self, in_dim, out_dim):
+        super(Projection, self).__init__()
+        self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1))
+
+    def forward(self, x):
+        """
+        x shape: batch, seqlen, in_dim
+        """
+        return x.matmul(self.projection)
+
+class Head(nn.Module):
+    def __init__(self, hidden_dim, output_dim):
+        super(Head, self).__init__()
+        self._scaling_factor = np.sqrt(output_dim)
+        self.query_projection = Projection(hidden_dim, output_dim)
+        self.key_projection = Projection(hidden_dim, output_dim)
+        self.value_projection = Projection(hidden_dim, output_dim)
+
+    def forward(self, x):
+        """
+        x shape: seqlen, batch, hiddendim
+        """
+        # get batch in front
+        x = torch.transpose(x, 0, 1)
+
+        query = self.query_projection(x)
+        key = self.key_projection(x)
+        value = self.value_projection(x)
+
+        # transpose the matrix dimensions of key to align for multiplication
+        product = query.matmul(torch.transpose(key, 1, 2))
+        # scaled dot-product attention weights: (batch, seqlen, seqlen)
+        weights = torch.softmax(product / self._scaling_factor, dim=2)
+        value = weights.matmul(value)
+        # move seqlen back in front: value (seqlen, batch, output_dim), weights (seqlen, batch, seqlen)
+        value = torch.transpose(value, 0, 1)
+        weights = torch.transpose(weights, 0, 1)
+
+        return value, weights
+
class Attention(nn.Module):
-    def __init__(self, hidden_dim, num_heads):
-        super(Attention, self).__init__()
-        # TODO: implement Attention
-        pass
-
-    def forward(self, x):
-        # TODO: implement Attention; return both result of attention mechanism and
-        # attention weights (for visualization).
-        # x shape: (seqlen, batch, hiddendim)
-        result, att_weights = x, None  # placeholder
-        pass
-        return result, att_weights
+    def __init__(self, hidden_dim, num_heads):
+        super(Attention, self).__init__()
+        self._num_heads = num_heads
+        self._head_output_dim = hidden_dim // num_heads
+        # ensure hidden_dim is divisible by num_heads
+        assert self._head_output_dim * num_heads == hidden_dim
+        self.heads = nn.ModuleList([
+            Head(hidden_dim, self._head_output_dim) for _ in range(num_heads)
+        ])
+        self.final_projection = nn.Linear(hidden_dim, hidden_dim)
+
+    def forward(self, x):
+        # x shape: (seqlen, batch, hiddendim)
+        # allocate outputs on the same device/dtype as the input
+        result = torch.zeros_like(x)
+        # attentions are (heads, seqlen, batch, seqlen)
+        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0],
+                                 device=x.device, dtype=x.dtype)
+        for i in range(self._num_heads):
+            # each head fills its own slice of the hidden dimension
+            from_index = i * self._head_output_dim
+            to_index = from_index + self._head_output_dim
+            result[:, :, from_index:to_index], attentions[i, :, :, :] = self.heads[i](x)
+
+        result = self.final_projection(result)
+        return result, attentions
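
As a quick sanity check of the shapes above, a minimal smoke test could look like the following. This sketch is not part of the patch; it assumes the repository root is on PYTHONPATH so that model.attention is importable, and the sizes are purely illustrative.

import torch

from model.attention import Attention

# illustrative sizes, not taken from the repository
seqlen, batch, hidden_dim, num_heads = 10, 4, 64, 8

attention = Attention(hidden_dim, num_heads)
x = torch.randn(seqlen, batch, hidden_dim)
result, att_weights = attention(x)

# result keeps the input layout; att_weights holds one map per head
assert result.shape == (seqlen, batch, hidden_dim)
assert att_weights.shape == (num_heads, seqlen, batch, seqlen)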