m-chrzan.xyz
about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- model/attention.py | 20
1 files changed, 3 insertions, 17 deletions
diff --git a/model/attention.py b/model/attention.py
index 48c9e29..ffc07d3 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -2,27 +2,13 @@ import numpy as np
import torch
from torch import nn
-class Projection(nn.Module):
- """
- Linear transformation by a matrix.
- """
- def __init__(self, in_dim, out_dim):
- super(Projection, self).__init__()
- self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, ))
-
- def forward(self, x):
- """
- x shape: batch, seqlen, in_dim
- """
- return x.matmul(self.projection)
-
class Head(nn.Module):
def __init__(self, hidden_dim, output_dim):
super(Head, self).__init__()
self._scaling_factor = np.sqrt(output_dim)
- self.query_projection = Projection(hidden_dim, output_dim)
- self.key_projection = Projection(hidden_dim, output_dim)
- self.value_projection = Projection(hidden_dim, output_dim)
+ self.query_projection = nn.Linear(hidden_dim, output_dim)
+ self.key_projection = nn.Linear(hidden_dim, output_dim)
+ self.value_projection = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
"""