m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train/train.py
blob: 1d8f26eeb88812e42f58c7f4cf28272237c5c798 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from time import time

import torch
from torch import nn
from torch import optim

from data.generate import get_single_example
from data.testset import get_testset

def do_verbose_test(model, n_tokens, seqlen, max_count):
    """Print one freshly generated example next to the model's prediction.

    Draws a single example with ``get_single_example(n_tokens, seqlen,
    max_count)`` and prints the input sequence, the expected output, and the
    argmax of the model's output over dim 2.

    The input tensor is created on the same device as the model's parameters
    (CPU if the model exposes none), so this works after ``model.to('cuda')``.
    No gradients are tracked for the forward pass.
    """
    print('verbose test:')
    x, y = get_single_example(n_tokens, seqlen, max_count)

    # Build the input where the model lives; a CPU tensor fed to a model whose
    # weights are on another device raises a device-mismatch error.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    # Wrap in a batch of one, then move the batch axis to dim 1, matching the
    # transpose(0, 1) layout train_model uses for its batches.
    # Assumes x is a flat sequence, giving shape (seqlen, 1) — TODO confirm.
    x = torch.tensor([x], device=device).transpose(0, 1)
    print('in          :', x.squeeze().cpu())
    print('expected out:', torch.tensor(y))

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        predicted = torch.argmax(model(x), dim=2).squeeze()
    print('model out   :', predicted.cpu())

def _test_accuracy(model, test_X, test_Y, n_classes):
    """Return the fraction of test positions whose argmax prediction matches test_Y.

    Puts the model in eval mode and runs the forward pass under
    ``torch.no_grad()`` so no autograd graph is built over the test set.
    Returns a 0-dim tensor.
    """
    model.eval()
    with torch.no_grad():
        predicted_logits = model(test_X).reshape(-1, n_classes)
    targets = test_Y.reshape(-1)
    return (torch.sum(torch.argmax(predicted_logits, dim=-1) == targets)
            / targets.shape[0])


def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count,
                device='cpu', detect_anomaly=True):
    """Train ``model`` with Adam + cross-entropy on freshly generated batches.

    Each step draws ``batch_size`` examples via ``get_single_example`` and
    stacks them into tensors of shape (seqlen, batch_size) — batch on dim 1.
    The model's output is flattened to ``(-1, max_count + 1)`` logits, one
    class per possible count.

    Roughly every 1% of steps (at least every step for runs shorter than 100
    steps), and always on the final step, the model is evaluated on the
    fixed test set from ``get_testset``. A summary is printed each time.

    Args:
        model: module to train; moved to ``device`` in place.
        lr: Adam learning rate.
        num_steps: number of optimizer steps.
        batch_size: examples per step.
        n_tokens, seqlen, max_count: forwarded to the data generators.
        device: device for the model, batches and test set.
        detect_anomaly: turn on ``torch.autograd.set_detect_anomaly``. Note
            this is a process-global switch that slows autograd.

    Returns:
        List of 0-dim accuracy tensors, one per evaluation.
    """
    if detect_anomaly:
        torch.autograd.set_detect_anomaly(True)
    model.to(device)
    n_classes = max_count + 1

    start_time = time()
    accs = []

    loss_function = nn.CrossEntropyLoss(
        # weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)

    test_X, test_Y = get_testset(n_tokens, seqlen, max_count)
    print('test size', test_X.shape)
    # The test set must live on the same device as the model for evaluation.
    test_X = torch.as_tensor(test_X, device=device)
    test_Y = torch.as_tensor(test_Y, device=device)

    # num_steps // 100 is 0 for short runs, which would make `step % ...`
    # raise ZeroDivisionError; clamp to evaluate at least every step.
    eval_every = max(1, num_steps // 100)

    for step in range(num_steps):
        batch_examples = [get_single_example(n_tokens, seqlen, max_count)
                          for _ in range(batch_size)]
        batch_X = torch.tensor([x for x, _ in batch_examples],
                               device=device).transpose(0, 1)
        batch_Y = torch.tensor([y for _, y in batch_examples],
                               device=device).transpose(0, 1)

        model.train()
        model.zero_grad()
        logits = model(batch_X)
        loss = loss_function(logits.reshape(-1, n_classes), batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()

        if step % eval_every == 0 or step == num_steps - 1:
            # Printing a summary of the current state of training every 1% of steps.
            test_acc = _test_accuracy(model, test_X, test_Y, n_classes)
            print('step', step, 'out of', num_steps)
            print('loss train', float(loss))
            print('accuracy test', float(test_acc))
            do_verbose_test(model, n_tokens, seqlen, max_count)
            print()
            accs.append(test_acc)
    print('\nTRAINING TIME:', time()-start_time)
    model.eval()
    return accs