m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/util/util.py
blob: 102ad93dac6d82541fa6be220e3a5a90bb801e65 (plain)
1
2
3
4
5
6
7
8
9
import torch
import numpy as np

def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
  """Build a fixed sinusoidal positional-encoding table.

  Entry ``[pos, 2i]`` is ``sin(pos / 10000 ** (2i / n_dimensions))`` and
  entry ``[pos, 2i + 1]`` is ``cos(pos / 10000 ** (2i / n_dimensions))``,
  so every sin/cos column pair shares one frequency and the wavelengths
  grow geometrically along the feature axis.

  Args:
    n_positions: number of positions (rows); position indices start at 0.
    n_dimensions: size of each encoding vector (columns). Odd values are
      supported; the last column is then a sine column without a cosine
      partner.
    device: torch device on which the returned tensor is allocated.

  Returns:
    A ``torch.float`` tensor of shape ``(n_positions, n_dimensions)``.
  """
  # Column vector of position indices, shape (n_positions, 1).
  positions = np.arange(n_positions, dtype=np.float64)[:, np.newaxis]

  # Columns 2i and 2i + 1 use the same frequency, so index frequencies by
  # column // 2. Shape (n_dimensions,).
  pair_index = np.arange(n_dimensions) // 2
  inverse_wavelengths = 1.0 / np.power(10000.0, 2.0 * pair_index / n_dimensions)

  # Broadcast to the full (n_positions, n_dimensions) grid of angles.
  angles = positions * inverse_wavelengths[np.newaxis, :]

  positional_encoding = np.zeros((n_positions, n_dimensions))
  positional_encoding[:, 0::2] = np.sin(angles[:, 0::2])  # even columns
  positional_encoding[:, 1::2] = np.cos(angles[:, 1::2])  # odd columns

  # output shape: (seqlen, hiddendim)
  return torch.tensor(positional_encoding, dtype=torch.float, device=device)