import torch

from data.generate import get_single_example

# Default number of examples generated by get_testset().
TEST_SIZE = 128


def get_testset(n_tokens=16, seqlen=64, max_count=9, device='cpu',
                test_size=TEST_SIZE):
    """Generate a test set and return it as sequence-first tensors.

    Calls ``get_single_example(n_tokens, seqlen, max_count)`` ``test_size``
    times. Element ``[0]`` of every example is gathered into the inputs and
    element ``[1]`` into the targets.

    Args:
        n_tokens: Forwarded to ``get_single_example``.
        seqlen: Forwarded to ``get_single_example``.
        max_count: Forwarded to ``get_single_example``.
        device: Device on which the returned tensors are created.
        test_size: Number of examples to generate. Defaults to
            ``TEST_SIZE``, which matches the previous fixed behavior.

    Returns:
        ``(test_X, test_Y)``: the gathered inputs and targets with their
        first two dimensions swapped, so the example (batch) index is
        dimension 1.

    Raises:
        ValueError: If ``test_size`` is less than 1.
    """
    if test_size < 1:
        # An empty list would become a 1-D tensor, on which transpose(0, 1)
        # fails with a much less helpful error.
        raise ValueError(f"test_size must be at least 1, got {test_size}")

    test_examples = [
        get_single_example(n_tokens, seqlen, max_count)
        for _ in range(test_size)
    ]

    # Gathering the examples yields (batch, seq, ...). This assumes each
    # example's [0]/[1] is a length-seqlen sequence; TODO confirm against
    # get_single_example. The transpositions are used because the PyTorch
    # convention (batch_first=False in the recurrent/transformer modules) is
    # to represent sequence tensors as (seq, batch, ...) instead of
    # (batch, seq, ...).
    test_X = torch.tensor(
        [example[0] for example in test_examples], device=device
    ).transpose(0, 1)
    test_Y = torch.tensor(
        [example[1] for example in test_examples], device=device
    ).transpose(0, 1)
    return test_X, test_Y