diff options
Diffstat (limited to 'data/generate.py')
-rw-r--r-- | data/generate.py | 7 |
1 files changed, 7 insertions, 0 deletions
# data/generate.py (added as a new file by this diff)
import numpy as np


def get_single_example(n_tokens=16, seqlen=64, max_count=9):
    """Generate one random token sequence with prior-occurrence labels.

    A sequence of ``seqlen`` token ids is drawn uniformly from
    ``[0, n_tokens)`` using NumPy's global random state (``np.random``).
    The label at position ``i`` is the number of times ``seq[i]`` has
    already appeared in ``seq[:i]``, capped at ``max_count``.

    Args:
        n_tokens: Exclusive upper bound of the token ids.
        seqlen: Number of tokens in the generated sequence.
        max_count: Largest label value; higher counts are clipped to it.

    Returns:
        A ``(seq, label)`` tuple of 1-D arrays, each of length ``seqlen``.
    """
    seq = np.random.randint(low=0, high=n_tokens, size=(seqlen,))

    # Keep a running tally per token instead of re-counting the whole
    # prefix at every position (which is quadratic in ``seqlen``).
    occurrences = {}
    label = []
    for token in seq.tolist():
        seen_before = occurrences.get(token, 0)
        label.append(min(seen_before, max_count))
        occurrences[token] = seen_before + 1

    # Explicit integer dtype so an empty sequence does not silently yield
    # a float array, unlike the non-empty case.
    return seq, np.array(label, dtype=int)