m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/util/util.py
blob: bb9e0c3cf9966ac82693ff96ec48fdc2b647aa92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import numpy as np

# Cache of the most recently built encoding table; None until the first call.
positional_encoding = None


def get_positional_encoding(n_positions: int, n_dimensions: int,
                            device='cpu') -> torch.Tensor:
    """Return a sinusoidal positional-encoding table.

    Row ``p`` (0-based) encodes position ``p + 1``. For column ``i`` the
    angle is ``(p + 1) / 10000 ** (2 * (i // 2) / n_dimensions)``; even
    columns hold the sine of that angle and odd columns the cosine.

    The table is cached in the module-level ``positional_encoding`` and is
    rebuilt whenever the requested shape differs from the cached one, so
    the result always matches the arguments of the current call.

    Args:
        n_positions: Number of rows (sequence positions) in the table.
        n_dimensions: Number of columns (encoding dimensions) in the table.
        device: Target passed to ``Tensor.to`` for the returned tensor.

    Returns:
        A tensor of shape ``(n_positions, n_dimensions)`` on ``device``.
        NOTE: per the ``Tensor.to`` documentation, no copy is made when the
        tensor is already on ``device``, so the result may be the cached
        tensor itself; avoid modifying it in place.
    """
    global positional_encoding
    requested_shape = (n_positions, n_dimensions)
    if (positional_encoding is None
            or tuple(positional_encoding.shape) != requested_shape):
        # Number positions from 1 instead of 0, to avoid repeated values in
        # first row of encoding
        positions = torch.arange(1, n_positions + 1).unsqueeze(1)
        # Columns 2k and 2k + 1 share the exponent 2k / n_dimensions.
        exponents = torch.arange(n_dimensions) // 2 * 2 / n_dimensions
        denominators = 10000 ** exponents
        # (n_positions, 1) / (n_dimensions,) broadcasts to the full table.
        table = positions / denominators
        table[:, ::2] = torch.sin(table[:, ::2])
        table[:, 1::2] = torch.cos(table[:, 1::2])
        positional_encoding = table
    # output shape: (seqlen, hiddendim)
    return positional_encoding.to(device)