diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:53:53 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:53:53 +0200 |
commit | 22cd84cf78a114d75a93f75f66f6ea61c02c94ef (patch) | |
tree | 95380360b2be0f5bb18d645c9821fdd202e851a5 /model | |
parent | 8b5f31f9aeeefe647e00a2c581ca69378210f12b (diff) |
Just use linear layer
Diffstat (limited to 'model')
-rw-r--r-- | model/attention.py | 20 |
1 files changed, 3 insertions, 17 deletions
diff --git a/model/attention.py b/model/attention.py index 48c9e29..ffc07d3 100644 --- a/model/attention.py +++ b/model/attention.py @@ -2,27 +2,13 @@ import numpy as np import torch from torch import nn -class Projection(nn.Module): - """ - Linear transformation by a matrix. - """ - def __init__(self, in_dim, out_dim): - super(Projection, self).__init__() - self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, )) - - def forward(self, x): - """ - x shape: batch, seqlen, in_dim - """ - return x.matmul(self.projection) - class Head(nn.Module): def __init__(self, hidden_dim, output_dim): super(Head, self).__init__() self._scaling_factor = np.sqrt(output_dim) - self.query_projection = Projection(hidden_dim, output_dim) - self.key_projection = Projection(hidden_dim, output_dim) - self.value_projection = Projection(hidden_dim, output_dim) + self.query_projection = nn.Linear(hidden_dim, output_dim) + self.key_projection = nn.Linear(hidden_dim, output_dim) + self.value_projection = nn.Linear(hidden_dim, output_dim) def forward(self, x): """ |