From 8ff8739b236a00169339b0b78e1f39357fdfff17 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sat, 22 May 2021 19:23:23 +0200 Subject: Generate data --- data/generate.py | 7 +++++++ data/testset.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 data/generate.py create mode 100644 data/testset.py diff --git a/data/generate.py b/data/generate.py new file mode 100644 index 0000000..4b0ba9c --- /dev/null +++ b/data/generate.py @@ -0,0 +1,7 @@ +import numpy as np + +def get_single_example(n_tokens=16, seqlen=64, max_count=9): + seq = np.random.randint(low=0, high=n_tokens, size=(seqlen,)) + label = [min(list(seq[:i]).count(x), max_count) for i, x in enumerate(seq)] + label = np.array(label) + return seq, label diff --git a/data/testset.py b/data/testset.py new file mode 100644 index 0000000..07fc811 --- /dev/null +++ b/data/testset.py @@ -0,0 +1,18 @@ +import torch + +from data.generate import get_single_example + +TEST_SIZE = 128 + +test_examples = [get_single_example() for i in range(TEST_SIZE)] + +def get_testset(device='cpu'): + # Transpositions are used, because the convention in PyTorch is to represent + # sequence tensors as instead of . + test_X = torch.tensor( + [x[0] for x in test_examples], device=device + ).transpose(0, 1) + test_Y = torch.tensor( + [x[1] for x in test_examples], device=device + ).transpose(0, 1) + return test_X, test_Y -- cgit v1.2.3