import numpy as np


def get_single_example(n_tokens=16, seqlen=64, max_count=9):
    """Generate one random token sequence and its capped prior-occurrence labels.

    Args:
        n_tokens: Exclusive upper bound for token ids; tokens are drawn
            from ``[0, n_tokens)`` with ``np.random.randint``.
        seqlen: Number of tokens in the sequence.
        max_count: Upper cap applied to every label value.

    Returns:
        A ``(seq, label)`` tuple of 1-D arrays of length ``seqlen``.
        ``label[i]`` is the number of times ``seq[i]`` already appeared in
        ``seq[:i]``, capped at ``max_count``.
    """
    seq = np.random.randint(low=0, high=n_tokens, size=(seqlen,))

    # Single pass with a running per-token tally instead of re-scanning the
    # whole prefix for every position (which is quadratic in seqlen).
    seen = {}
    label = []
    for token in seq.tolist():
        prior = seen.get(token, 0)
        label.append(min(prior, max_count))
        seen[token] = prior + 1

    if not label:
        # np.array([]) would default to float64; keep the empty label array
        # integer-typed like the token array.
        return seq, np.array(label, dtype=seq.dtype)

    label = np.array(label)
    return seq, label