author     Marcin Chrzanowski <m@m-chrzan.xyz>   2021-05-27 19:53:53 +0200
committer  Marcin Chrzanowski <m@m-chrzan.xyz>   2021-05-27 19:53:53 +0200
commit     22cd84cf78a114d75a93f75f66f6ea61c02c94ef (patch)
tree       95380360b2be0f5bb18d645c9821fdd202e851a5 /model
parent     8b5f31f9aeeefe647e00a2c581ca69378210f12b (diff)
Just use linear layer
Diffstat (limited to 'model')
-rw-r--r--  model/attention.py  20
1 file changed, 3 insertions, 17 deletions
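
The removed Projection module applied a bare weight matrix (no bias term, entries drawn from N(0, 1)), while nn.Linear adds a bias by default and uses its own default initialization, so the swap slightly changes the parameterization as well as simplifying the code. A minimal sketch of the two approaches side by side (the dimensions and tensor names below are made up for illustration; bias=False is only needed to mirror the removed module exactly):

import torch
from torch import nn

in_dim, out_dim = 64, 16
x = torch.randn(8, 10, in_dim)  # batch, seqlen, in_dim

# Old approach: a bare parameter matrix applied with matmul (no bias term).
w = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1))
old_out = x.matmul(w)

# New approach: nn.Linear performs the same batched projection over the last
# dimension; bias=False makes it match the removed Projection module in
# parameter count.
proj = nn.Linear(in_dim, out_dim, bias=False)
new_out = proj(x)

assert old_out.shape == new_out.shape == (8, 10, out_dim)
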
diff --git a/model/attention.py b/model/attention.py
index 48c9e29..ffc07d3 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -2,27 +2,13 @@ import numpy as np
import torch
from torch import nn

-class Projection(nn.Module):
-    """
-    Linear transformation by a matrix.
-    """
-    def __init__(self, in_dim, out_dim):
-        super(Projection, self).__init__()
-        self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, ))
-
-    def forward(self, x):
-        """
-        x shape: batch, seqlen, in_dim
-        """
-        return x.matmul(self.projection)
-
class Head(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(Head, self).__init__()
        self._scaling_factor = np.sqrt(output_dim)
-        self.query_projection = Projection(hidden_dim, output_dim)
-        self.key_projection = Projection(hidden_dim, output_dim)
-        self.value_projection = Projection(hidden_dim, output_dim)
+        self.query_projection = nn.Linear(hidden_dim, output_dim)
+        self.key_projection = nn.Linear(hidden_dim, output_dim)
+        self.value_projection = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """