m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/data/testset.py
blob: 07fc8118b3ff0948a7b70c11c3d8ae8556d3c47f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

from data.generate import get_single_example

# Number of examples held in the fixed test set.
TEST_SIZE = 128

# Generated once, when this module is first imported, so every call to
# get_testset() is built from the same list of examples.
test_examples = [get_single_example() for _ in range(TEST_SIZE)]

def get_testset(device='cpu'):
    """Return the module-level test examples as a ``(test_X, test_Y)`` pair.

    Element 0 of every entry in ``test_examples`` is gathered into
    ``test_X`` and element 1 into ``test_Y``. Both tensors are created on
    ``device`` and then have their first two dimensions swapped.
    """
    def field_as_tensor(index):
        # Gather one field across all examples, then swap dims 0 and 1:
        # PyTorch's convention for sequence tensors is
        # <seq_len, batch_size> rather than <batch_size, seq_len>.
        batch_first = torch.tensor(
            [example[index] for example in test_examples],
            device=device,
        )
        return torch.transpose(batch_first, 0, 1)

    inputs = field_as_tensor(0)
    targets = field_as_tensor(1)
    return inputs, targets