m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/data/testset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/testset.py')
-rw-r--r--data/testset.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/data/testset.py b/data/testset.py
new file mode 100644
index 0000000..07fc811
--- /dev/null
+++ b/data/testset.py
@@ -0,0 +1,18 @@
+import torch
+
+from data.generate import get_single_example
+
+TEST_SIZE = 128
+
+test_examples = [get_single_example() for i in range(TEST_SIZE)]
+
+def get_testset(device='cpu'):
+ # Transpositions are used, because the convention in PyTorch is to represent
+ # sequence tensors as <seq_len, batch_size> instead of <batch_size, seq_len>.
+ test_X = torch.tensor(
+ [x[0] for x in test_examples], device=device
+ ).transpose(0, 1)
+ test_Y = torch.tensor(
+ [x[1] for x in test_examples], device=device
+ ).transpose(0, 1)
+ return test_X, test_Y