from time import time

import torch
from torch import nn
from torch import optim

from data.generate import get_single_example
from data.testset import get_testset
def train_model(model, lr, num_steps, batch_size, device='cpu'):
    model.to(device)
    start_time = time()
    accs = []

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load the fixed test set once and keep it on the training device.
    test_X, test_Y = get_testset()
    test_X, test_Y = test_X.to(device), test_Y.to(device)
    for step in range(num_steps):
        # Sample a fresh batch of training examples every step.
        batch_examples = [get_single_example() for i in range(batch_size)]

        # Stack the examples and put the sequence dimension first,
        # giving (seq_len, batch_size) tensors.
        batch_X = torch.tensor([x[0] for x in batch_examples],
                               device=device).transpose(0, 1)
        batch_Y = torch.tensor([x[1] for x in batch_examples],
                               device=device).transpose(0, 1)

        model.train()
        model.zero_grad()
        logits = model(batch_X)
        # Flatten to (seq_len * batch_size, 10) token-level predictions
        # before computing the cross-entropy loss.
        loss = loss_function(logits.reshape(-1, 10), batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()
        # Print a summary of the current state of training every 1% of
        # steps (max(1, ...) guards against num_steps < 100).
        if step % max(1, num_steps // 100) == 0 or step == num_steps - 1:
            model.eval()
            with torch.no_grad():  # no gradients needed for evaluation
                predicted_logits = model(test_X).reshape(-1, 10)
                test_acc = (
                    torch.sum(torch.argmax(predicted_logits, dim=-1)
                              == test_Y.reshape(-1))
                    / test_Y.reshape(-1).shape[0])
            print('step', step, 'out of', num_steps)
            print('loss train', float(loss))
            print('accuracy test', float(test_acc))
            print()
            accs.append(float(test_acc))
    print('\nTRAINING TIME:', time() - start_time)
    model.eval()
    return accs
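
# A minimal compatibility sketch (not from the source): train_model assumes a
# model whose forward pass maps a (seq_len, batch) tensor of integer tokens to
# (seq_len, batch, 10) logits. The toy LSTM tagger below is an illustrative
# placeholder; the vocabulary size of 10 is an assumption based on the
# 10-way output used above.
class ToyTagger(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.head = nn.Linear(hidden_size, 10)

    def forward(self, x):                # x: (seq_len, batch)
        h, _ = self.lstm(self.embed(x))  # h: (seq_len, batch, hidden_size)
        return self.head(h)              # logits: (seq_len, batch, 10)

# Example call (hyperparameters are illustrative, not from the source):
# accs = train_model(ToyTagger(), lr=1e-3, num_steps=1000, batch_size=64)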