diff options
Diffstat (limited to 'data/testset.py')
-rw-r--r-- | data/testset.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/data/testset.py b/data/testset.py new file mode 100644 index 0000000..07fc811 --- /dev/null +++ b/data/testset.py @@ -0,0 +1,18 @@ +import torch + +from data.generate import get_single_example + +TEST_SIZE = 128 + +test_examples = [get_single_example() for i in range(TEST_SIZE)] + +def get_testset(device='cpu'): + # Transpositions are used, because the convention in PyTorch is to represent + # sequence tensors as <seq_len, batch_size> instead of <batch_size, seq_len>. + test_X = torch.tensor( + [x[0] for x in test_examples], device=device + ).transpose(0, 1) + test_Y = torch.tensor( + [x[1] for x in test_examples], device=device + ).transpose(0, 1) + return test_X, test_Y |