blob: f8b938e1bc9c61341df68abd6bd83587be077d5c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
import torch
from data.generate import get_single_example
TEST_SIZE = 128
def get_testset(n_tokens=16, seqlen=64, max_count=9, device='cpu',
                test_size=TEST_SIZE):
    """Build a fixed evaluation set of (input, target) tensor pairs.

    Args:
        n_tokens: vocabulary-size argument forwarded to get_single_example.
        seqlen: sequence-length argument forwarded to get_single_example.
        max_count: max-count argument forwarded to get_single_example.
        device: torch device on which the result tensors are created.
        test_size: number of examples to generate (defaults to the module's
            TEST_SIZE; parameterized so callers can request other sizes).

    Returns:
        Tuple (test_X, test_Y). Each is built from the examples' first/second
        components respectively and transposed, so presumably shaped
        <seqlen, test_size> — exact shape depends on get_single_example's
        output; confirm against its definition.
    """
    examples = [
        get_single_example(n_tokens, seqlen, max_count)
        for _ in range(test_size)  # `_`: loop index is unused
    ]
    # Transpositions are used, because the convention in PyTorch is to
    # represent sequence tensors as <seq_len, batch_size> instead of
    # <batch_size, seq_len>.
    test_X = torch.tensor(
        [x[0] for x in examples], device=device
    ).transpose(0, 1)
    test_Y = torch.tensor(
        [x[1] for x in examples], device=device
    ).transpose(0, 1)
    return test_X, test_Y
|