import torch
import numpy as np

# Cached positional-encoding table, built lazily on the first call.
positional_encoding = None


def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
    global positional_encoding
    if positional_encoding is None:
        # Number positions from 1 instead of 0, to avoid repeated values in
        # the first row of the encoding.
        numerators = 1 + torch.arange(n_positions).repeat(n_dimensions, 1).T
        # Per-dimension scale 10000^(2i / n_dimensions), where i = dim // 2,
        # so each sin/cos pair shares the same exponent.
        denominators = 10000 ** (torch.arange(n_dimensions) // 2 * 2 / n_dimensions)
        positional_encoding = numerators / denominators
        # Even columns get sine, odd columns get cosine.
        positional_encoding[:, ::2] = torch.sin(positional_encoding[:, ::2])
        positional_encoding[:, 1::2] = torch.cos(positional_encoding[:, 1::2])
    # Note: the cache assumes n_positions and n_dimensions stay the same
    # across calls.
    # Output shape: (seqlen, hiddendim)
    return positional_encoding.to(device)
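
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the table returned above can be combined with
# token embeddings; the batch size, sequence length, and model width below
# are hypothetical placeholder values.
if __name__ == '__main__':
    batch_size, seq_len, d_model = 2, 16, 64
    token_embeddings = torch.randn(batch_size, seq_len, d_model)
    # (seq_len, d_model) table, broadcast over the batch dimension when added.
    pe = get_positional_encoding(seq_len, d_model)
    model_inputs = token_embeddings + pe.unsqueeze(0)
    print(model_inputs.shape)  # torch.Size([2, 16, 64])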