m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
author Marcin Chrzanowski <m@m-chrzan.xyz> 2021-05-25 17:36:02 +0200
committer Marcin Chrzanowski <m@m-chrzan.xyz> 2021-05-25 17:36:02 +0200
commit a2dd5a009a73d1ea5f03894320af818c05c5778d (patch)
tree a53abc8405c04a0cd4f0f31401f0cc7cff7dbdd7 /data
parent 0b67d2c000f31c9048c2ea346448184f1ce97e0d (diff)
Allow for configurable test set parameters
Diffstat (limited to 'data')
-rw-r--r-- data/testset.py 5
1 file changed, 2 insertions, 3 deletions
diff --git a/data/testset.py b/data/testset.py
index 07fc811..f8b938e 100644
--- a/data/testset.py
+++ b/data/testset.py
@@ -4,9 +4,8 @@ from data.generate import get_single_example
TEST_SIZE = 128
-test_examples = [get_single_example() for i in range(TEST_SIZE)]
-
-def get_testset(device='cpu'):
+def get_testset(n_tokens=16, seqlen=64, max_count=9, device='cpu'):
+ test_examples = [get_single_example(n_tokens, seqlen, max_count) for i in range(TEST_SIZE)]
# Transpositions are used, because the convention in PyTorch is to represent
# sequence tensors as <seq_len, batch_size> instead of <batch_size, seq_len>.
test_X = torch.tensor(