blob: 07fc8118b3ff0948a7b70c11c3d8ae8556d3c47f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import torch
from data.generate import get_single_example
# Number of examples in the test set built by get_testset().
TEST_SIZE = 128

# Generated once, at import time, so every get_testset() call in this process
# reuses the same examples. Whether they are reproducible across processes
# depends on get_single_example (e.g. whether it is seeded) -- not visible here.
# get_testset() reads element [0] as the input and [1] as the target.
test_examples = [get_single_example() for _ in range(TEST_SIZE)]
def get_testset(device='cpu', examples=None):
    """Convert the test examples into a pair of sequence-major tensors.

    Args:
        device: Device on which both tensors are created; forwarded
            unchanged to ``torch.tensor``.
        examples: Optional sequence of ``(input, target)`` pairs to convert.
            When ``None`` (the default), the module-level ``test_examples``
            generated at import time is used, so existing callers get exactly
            the same data as before.

    Returns:
        A tuple ``(test_X, test_Y)``. Each tensor is first built with one row
        per example (dim 0 = example index) and then transposed, so in the
        result dim 0 is the position within the sequence and dim 1 is the
        example index.

    Note:
        ``torch.tensor`` needs a rectangular nested list, so every example's
        input (and, separately, every target) is assumed to have the same
        length -- presumably guaranteed by the data generator; confirm.
    """
    if examples is None:
        examples = test_examples

    # Transpositions are used, because the convention in PyTorch is to represent
    # sequence tensors as <seq_len, batch_size> instead of <batch_size, seq_len>.
    # transpose() returns a (non-contiguous) view rather than a copy.
    test_X = torch.tensor(
        [x[0] for x in examples], device=device
    ).transpose(0, 1)
    test_Y = torch.tensor(
        [x[1] for x in examples], device=device
    ).transpose(0, 1)
    return test_X, test_Y
|