1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
from time import time
import torch
from torch import nn
from torch import optim
from data.generate import get_single_example
from data.testset import get_testset
def do_verbose_test(model, n_tokens, seqlen, max_count):
    """Print one freshly generated example next to the model's prediction.

    Draws a single example with ``get_single_example(n_tokens, seqlen,
    max_count)``. It then prints:

    - the input,
    - the expected output,
    - the arg-max over dim 2 of the model's output for that input.

    The input is shaped ``(len(x), 1)``: the sequence comes first and the
    batch has size one. The model is run on the same device as its
    parameters, so this also works when the model is not on the CPU.

    The model's train/eval mode is not changed here. Put the model in eval
    mode first if it contains mode-dependent layers.
    """
    print('verbose test:')
    x, y = get_single_example(n_tokens, seqlen, max_count)
    # [x] adds a batch dimension of 1; transpose to sequence-first layout.
    x = torch.tensor([x]).transpose(0, 1)
    print('in :', x.squeeze())
    print('expected out:', torch.tensor(y))
    # Run on the model's device (first parameter); fall back to CPU for a
    # parameter-free model.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    with torch.no_grad():  # inference only: no autograd graph needed
        prediction = torch.argmax(model(x.to(device)), dim=2).squeeze().cpu()
    print('model out :', prediction)
def _sample_batch(batch_size, n_tokens, seqlen, max_count, device):
    """Draw ``batch_size`` fresh examples; return sequence-first ``(X, Y)`` tensors on ``device``."""
    examples = [get_single_example(n_tokens, seqlen, max_count) for _ in range(batch_size)]
    # Stack to (batch, ...) and transpose so the sequence dimension comes first.
    batch_X = torch.tensor([ex[0] for ex in examples], device=device).transpose(0, 1)
    batch_Y = torch.tensor([ex[1] for ex in examples], device=device).transpose(0, 1)
    return batch_X, batch_Y


def _test_accuracy(model, test_X, test_Y_flat, max_count):
    """Return, as a 0-d float tensor, the fraction of positions where the arg-max class equals the target."""
    with torch.no_grad():  # evaluation only: don't build a graph over the whole test set
        logits = model(test_X).reshape(-1, max_count + 1)
    return (torch.argmax(logits, dim=-1) == test_Y_flat).float().mean()


def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, device='cpu',
                detect_anomaly=True):
    """Train ``model`` with Adam on randomly generated examples.

    Each step samples ``batch_size`` fresh examples via ``get_single_example``,
    stacks them into sequence-first tensors and takes one Adam step. The loss is
    the cross-entropy between the model's logits, flattened to
    ``max_count + 1`` classes, and the flattened targets.

    Evaluation happens about every 1% of ``num_steps`` and on the last step.
    It uses the fixed test set returned by ``get_testset`` and prints the
    current train loss, the test accuracy and one verbose example.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device``.
        lr: Adam learning rate.
        num_steps: Number of optimisation steps.
        batch_size: Number of examples per step.
        n_tokens, seqlen, max_count: Forwarded to the data generators;
            ``max_count + 1`` is used as the number of output classes.
        device: Device for the model, the training batches and the test set.
        detect_anomaly: If true (the original, always-on behaviour), enable
            ``torch.autograd`` anomaly detection. It is a debugging aid that
            slows training considerably; pass ``False`` for normal runs.

    Returns:
        A list of 0-d tensors, the test accuracy at each evaluation point.
    """
    if detect_anomaly:
        # NOTE: this sets a process-wide autograd flag; it is not reset on return.
        torch.autograd.set_detect_anomaly(True)
    model.to(device)
    start_time = time()
    accs = []
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    test_X, test_Y = get_testset(n_tokens, seqlen, max_count)
    print('test size', test_X.shape)
    # Move the test set next to the model once, instead of failing (off-CPU)
    # or copying on every evaluation.
    test_X = torch.as_tensor(test_X, device=device)
    test_Y_flat = torch.as_tensor(test_Y, device=device).reshape(-1)

    # Evaluate roughly 100 times per run. max(1, ...) avoids a
    # ZeroDivisionError when num_steps < 100.
    eval_every = max(1, num_steps // 100)

    for step in range(num_steps):
        batch_X, batch_Y = _sample_batch(batch_size, n_tokens, seqlen, max_count, device)
        model.train()
        optimizer.zero_grad()
        logits = model(batch_X)
        loss = loss_function(logits.reshape(-1, max_count + 1), batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()

        if step % eval_every == 0 or step == num_steps - 1:
            # Print a summary of the current state of training every ~1% of steps.
            model.eval()
            test_acc = _test_accuracy(model, test_X, test_Y_flat, max_count)
            print('step', step, 'out of', num_steps)
            print('loss train', float(loss))
            print('accuracy test', float(test_acc))
            do_verbose_test(model, n_tokens, seqlen, max_count)
            print()
            accs.append(test_acc)

    print('\nTRAINING TIME:', time() - start_time)
    model.eval()
    return accs
|