m-chrzan.xyz
about summary refs log tree commit diff
path: root/data/generate.py
blob: 4b0ba9c6ad96397482a192ff1fb4a333fe0e5796 (plain)
1
2
3
4
5
6
7
import numpy as np

def get_single_example(n_tokens=16, seqlen=64, max_count=9, rng=None):
  """Generate one random token sequence and its running-count labels.

  Tokens are drawn uniformly from [0, n_tokens). For each position i, the
  label is the number of times seq[i] appeared strictly before position i,
  capped at max_count.

  Args:
    n_tokens: Exclusive upper bound of the token ids.
    seqlen: Length of the generated sequence.
    max_count: Cap applied to every label.
    rng: Optional np.random.Generator for reproducible sampling. When None,
      the legacy global np.random state is used (original behavior).

  Returns:
    (seq, label): two 1-D integer arrays of length seqlen.
  """
  if rng is None:
    # Keep the legacy global-state call so seeded callers see identical data.
    seq = np.random.randint(low=0, high=n_tokens, size=(seqlen,))
  else:
    seq = rng.integers(low=0, high=n_tokens, size=(seqlen,))

  # Running count of each token seen so far: O(seqlen) rather than re-counting
  # the whole prefix at every position.
  seen = {}
  label = []
  for token in seq.tolist():
    count = seen.get(token, 0)
    label.append(min(count, max_count))
    seen[token] = count + 1
  # dtype=int keeps the empty-sequence case integer-typed (np.array([]) would
  # otherwise be float64) and matches the default for non-empty lists.
  return seq, np.array(label, dtype=int)