# train/train.py
from time import time

import torch
from torch import nn
from torch import optim

from data.generate import get_single_example
from data.testset import get_testset

def train_model(model, lr, num_steps, batch_size, device='cpu'):
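  """Train `model` with Adam on batches drawn from `get_single_example`.

  Logs train loss and test-set accuracy roughly every 1% of steps and
  returns the list of recorded test accuracies.
  """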
  model.to(device)

  start_time = time()
  accs = []

  loss_function = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=lr)

  test_X, test_Y = get_testset()
  # Keep the fixed test set on the same device as the model, so
  # evaluation also works when device != 'cpu'.
  test_X, test_Y = test_X.to(device), test_Y.to(device)

  for step in range(num_steps):
    batch_examples = [get_single_example() for _ in range(batch_size)]

    # Each example is an (input, target) pair. Stack the batch, then
    # transpose so the leading dimension is sequence position rather
    # than batch: (seq_len, batch_size).
    batch_X = torch.tensor([x[0] for x in batch_examples],
                           device=device).transpose(0, 1)
    batch_Y = torch.tensor([x[1] for x in batch_examples],
                           device=device).transpose(0, 1)

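    # One optimization step: forward pass, cross-entropy over the
    # flattened (position, class) logits, backprop, parameter update.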
    model.train()
    model.zero_grad()
    logits = model(batch_X)
    loss = loss_function(logits.reshape(-1, 10), batch_Y.reshape(-1))
    loss.backward()
    optimizer.step()

    if step % max(1, num_steps // 100) == 0 or step == num_steps - 1:
      # Print a summary of the current state of training every 1% of
      # steps; the max(1, ...) guard avoids a modulo-by-zero when
      # num_steps < 100.
      model.eval()
      # Evaluate on the held-out test set without tracking gradients.
      with torch.no_grad():
        predicted_logits = model(test_X).reshape(-1, 10)
        test_acc = (
            torch.sum(torch.argmax(predicted_logits, dim=-1)
                      == test_Y.reshape(-1))
            / test_Y.reshape(-1).shape[0])
      print('step', step, 'out of', num_steps)
      print('loss train', float(loss))
      print('accuracy test', float(test_acc))
      print()
      accs.append(float(test_acc))
  print('\nTRAINING TIME:', time() - start_time)
  model.eval()
  return accs
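

# A minimal usage sketch, added for illustration and not part of the
# original file. It assumes the model takes integer tokens shaped
# (seq_len, batch_size) and returns logits shaped
# (seq_len, batch_size, 10), which is what the reshape(-1, 10) above
# implies. TinyModel and every hyperparameter below are hypothetical
# guesses, not values from this repository.
if __name__ == '__main__':
  class TinyModel(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=32):
      super().__init__()
      self.embed = nn.Embedding(vocab_size, hidden_size)
      self.rnn = nn.LSTM(hidden_size, hidden_size)  # seq-first by default
      self.head = nn.Linear(hidden_size, 10)

    def forward(self, x):
      out, _ = self.rnn(self.embed(x))  # (seq_len, batch, hidden)
      return self.head(out)             # (seq_len, batch, 10)

  accs = train_model(TinyModel(), lr=1e-3, num_steps=200, batch_size=64)
  print('final test accuracy:', accs[-1])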