m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/util.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/util/util.py b/util/util.py
index 102ad93..419c23a 100644
--- a/util/util.py
+++ b/util/util.py
@@ -1,9 +1,16 @@
import torch
import numpy as np
+positional_encoding = None
+
def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
- # TODO: implement positional encoding
- positional_encoding = np.zeros((n_positions, n_dimensions)) # placeholder
- pass
- # output shape: (seqlen, hiddendim)
- return torch.tensor(positional_encoding, dtype=torch.float, device=device)
+ global positional_encoding
+ if positional_encoding is None:
+ numerators = torch.tensor(range(n_positions)).repeat(n_dimensions, 1).T
+ denominators = 10000 ** (torch.tensor(range(n_dimensions)) // 2 * 2 / n_dimensions)
+ print('denoms:', denominators)
+ positional_encoding = numerators / denominators
+ positional_encoding[:, ::2] = torch.sin(positional_encoding[:, ::2])
+ positional_encoding[:, 1::2] = torch.cos(positional_encoding[:, 1::2])
+ # output shape: (seqlen, hiddendim)
+ return torch.tensor(positional_encoding, dtype=torch.float, device=device)