m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/data/testset.py
blob: f8b938e1bc9c61341df68abd6bd83587be077d5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

from data.generate import get_single_example

# Number of examples generated by get_testset() for the test set.
TEST_SIZE = 128

def get_testset(n_tokens=16, seqlen=64, max_count=9, device='cpu', test_size=None):
    """Generate a batch of test examples as time-major tensors.

    Draws ``test_size`` examples from ``get_single_example`` and stacks their
    first and second fields into two tensors created on ``device``.

    Args:
        n_tokens: Forwarded unchanged to ``get_single_example``.
        seqlen: Forwarded unchanged to ``get_single_example``.
        max_count: Forwarded unchanged to ``get_single_example``.
        device: Device on which the returned tensors are created.
        test_size: Number of examples to generate. If ``None`` (the
            default), the module-level ``TEST_SIZE`` is used, read at call
            time.

    Returns:
        A ``(test_X, test_Y)`` pair. Each tensor is built by stacking the
        per-example field along dim 0 and then swapping dims 0 and 1, so the
        example (batch) dimension ends up at dim 1. Presumably ``test_X`` is
        ``(seqlen, test_size)``; TODO confirm against ``get_single_example``.

    Raises:
        ValueError: If ``test_size`` is less than 1.
    """
    if test_size is None:
        # Resolved here rather than as a default value so that a reassigned
        # module-level TEST_SIZE is still honoured, as in the original code.
        test_size = TEST_SIZE
    if test_size < 1:
        # An empty batch yields a 1-D tensor, which transpose(0, 1) rejects
        # with an unhelpful IndexError. Fail early with a clear message.
        raise ValueError(f'test_size must be at least 1, got {test_size}')

    examples = [
        get_single_example(n_tokens, seqlen, max_count) for _ in range(test_size)
    ]
    # Transpositions are used, because the convention in PyTorch is to represent
    # sequence tensors as <seq_len, batch_size> instead of <batch_size, seq_len>.
    test_X = torch.tensor(
        [ex[0] for ex in examples], device=device
    ).transpose(0, 1)
    test_Y = torch.tensor(
        [ex[1] for ex in examples], device=device
    ).transpose(0, 1)
    return test_X, test_Y