diff options
Diffstat (limited to 'model')
-rw-r--r-- | model/attention.py | 20 |
1 files changed, 3 insertions, 17 deletions
diff --git a/model/attention.py b/model/attention.py index 48c9e29..ffc07d3 100644 --- a/model/attention.py +++ b/model/attention.py @@ -2,27 +2,13 @@ import numpy as np import torch from torch import nn -class Projection(nn.Module): - """ - Linear transformation by a matrix. - """ - def __init__(self, in_dim, out_dim): - super(Projection, self).__init__() - self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, )) - - def forward(self, x): - """ - x shape: batch, seqlen, in_dim - """ - return x.matmul(self.projection) - class Head(nn.Module): def __init__(self, hidden_dim, output_dim): super(Head, self).__init__() self._scaling_factor = np.sqrt(output_dim) - self.query_projection = Projection(hidden_dim, output_dim) - self.key_projection = Projection(hidden_dim, output_dim) - self.value_projection = Projection(hidden_dim, output_dim) + self.query_projection = nn.Linear(hidden_dim, output_dim) + self.key_projection = nn.Linear(hidden_dim, output_dim) + self.value_projection = nn.Linear(hidden_dim, output_dim) def forward(self, x): """ |