m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--util/util.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/util/util.py b/util/util.py
index 419c23a..65f1838 100644
--- a/util/util.py
+++ b/util/util.py
@@ -6,11 +6,12 @@ positional_encoding = None
def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
global positional_encoding
if positional_encoding is None:
- numerators = torch.tensor(range(n_positions)).repeat(n_dimensions, 1).T
+ # Number positions from 1 instead of 0, to avoid repeated values in
+ # first row of encoding
+ numerators = 1 + torch.tensor(range(n_positions)).repeat(n_dimensions, 1).T
denominators = 10000 ** (torch.tensor(range(n_dimensions)) // 2 * 2 / n_dimensions)
- print('denoms:', denominators)
positional_encoding = numerators / denominators
positional_encoding[:, ::2] = torch.sin(positional_encoding[:, ::2])
positional_encoding[:, 1::2] = torch.cos(positional_encoding[:, 1::2])
# output shape: (seqlen, hiddendim)
- return torch.tensor(positional_encoding, dtype=torch.float, device=device)
+ return positional_encoding