From 22cd84cf78a114d75a93f75f66f6ea61c02c94ef Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Thu, 27 May 2021 19:53:53 +0200 Subject: Just use linear layer --- model/attention.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) (limited to 'model') diff --git a/model/attention.py b/model/attention.py index 48c9e29..ffc07d3 100644 --- a/model/attention.py +++ b/model/attention.py @@ -2,27 +2,13 @@ import numpy as np import torch from torch import nn -class Projection(nn.Module): - """ - Linear transformation by a matrix. - """ - def __init__(self, in_dim, out_dim): - super(Projection, self).__init__() - self.projection = nn.Parameter(torch.normal(torch.zeros(in_dim, out_dim), 1, )) - - def forward(self, x): - """ - x shape: batch, seqlen, in_dim - """ - return x.matmul(self.projection) - class Head(nn.Module): def __init__(self, hidden_dim, output_dim): super(Head, self).__init__() self._scaling_factor = np.sqrt(output_dim) - self.query_projection = Projection(hidden_dim, output_dim) - self.key_projection = Projection(hidden_dim, output_dim) - self.value_projection = Projection(hidden_dim, output_dim) + self.query_projection = nn.Linear(hidden_dim, output_dim) + self.key_projection = nn.Linear(hidden_dim, output_dim) + self.value_projection = nn.Linear(hidden_dim, output_dim) def forward(self, x): """ -- cgit v1.2.3