blob: 4b0ba9c6ad96397482a192ff1fb4a333fe0e5796 (
plain)
1
2
3
4
5
6
7
|
import numpy as np
def get_single_example(n_tokens=16, seqlen=64, max_count=9):
    """Generate one random token sequence with "seen-so-far" count labels.

    Args:
        n_tokens: Exclusive upper bound for token ids; tokens are drawn with
            ``np.random.randint`` from ``[0, n_tokens)``.
        seqlen: Number of tokens in the generated sequence.
        max_count: Cap applied to every label.

    Returns:
        A ``(seq, label)`` tuple of 1-D arrays of length ``seqlen``.
        ``label[i]`` is the number of times ``seq[i]`` occurred in
        ``seq[:i]`` (i.e. strictly before position ``i``), clipped to
        ``max_count``.
    """
    seq = np.random.randint(low=0, high=n_tokens, size=(seqlen,))

    # Keep a running tally per token instead of re-counting the whole prefix
    # at every position (one pass rather than O(seqlen ** 2)).
    seen = {}
    label = []
    for token in seq.tolist():
        previous = seen.get(token, 0)
        label.append(min(previous, max_count))
        seen[token] = previous + 1

    if not label:
        # np.array([]) defaults to float64; keep the empty label array
        # integer-typed like the non-empty case.
        return seq, np.zeros(0, dtype=int)
    return seq, np.array(label)
|