From a2dd5a009a73d1ea5f03894320af818c05c5778d Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Tue, 25 May 2021 17:36:02 +0200 Subject: Allow for configurable test set parameters --- data/testset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/data/testset.py b/data/testset.py index 07fc811..f8b938e 100644 --- a/data/testset.py +++ b/data/testset.py @@ -4,9 +4,8 @@ from data.generate import get_single_example TEST_SIZE = 128 -test_examples = [get_single_example() for i in range(TEST_SIZE)] - -def get_testset(device='cpu'): +def get_testset(n_tokens=16, seqlen=64, max_count=9, device='cpu'): + test_examples = [get_single_example(n_tokens, seqlen, max_count) for i in range(TEST_SIZE)] # Transpositions are used, because the convention in PyTorch is to represent # sequence tensors as instead of . test_X = torch.tensor( -- cgit v1.2.3