import sys
from time import time

import torch
from torch import nn
from torch import optim

from data.generate import get_single_example
from data.testset import get_testset


def do_verbose_test(model, n_tokens, seqlen, max_count, device='cpu'):
    """Print one freshly generated example next to the model's prediction.

    A single example is drawn with ``get_single_example`` and fed through
    ``model`` as a batch of one. The input, the expected targets and the
    argmax over dim 2 of the model output are printed for eyeballing.

    Args:
        model: Module mapping a batch of token ids (batch on dimension 1)
            to logits whose class scores are on dimension 2.
        n_tokens: Forwarded to ``get_single_example``.
        seqlen: Forwarded to ``get_single_example``.
        max_count: Forwarded to ``get_single_example``.
        device: Device the model lives on; the input tensor is created there
            so this also works for a model on a GPU. Defaults to ``'cpu'``.
    """
    print('verbose test:')
    x, y = get_single_example(n_tokens, seqlen, max_count)
    # Wrap in a list to form a batch of one, then transpose so the batch
    # index is dimension 1 -- the same layout _make_batch produces.
    x = torch.tensor([x], device=device).transpose(0, 1)
    print('in :', x.squeeze())
    print('expected out:', torch.tensor(y))
    # Inference only: no autograd graph is needed for a printout.
    with torch.no_grad():
        prediction = torch.argmax(model(x), dim=2).squeeze()
    print('model out :', prediction)


def _make_batch(batch_size, n_tokens, seqlen, max_count, device):
    """Sample ``batch_size`` fresh examples; return ``(X, Y)`` with the batch on dim 1."""
    examples = [get_single_example(n_tokens, seqlen, max_count)
                for _ in range(batch_size)]
    batch_X = torch.tensor([x for x, _ in examples], device=device).transpose(0, 1)
    batch_Y = torch.tensor([y for _, y in examples], device=device).transpose(0, 1)
    return batch_X, batch_Y


def _evaluate(model, test_X, test_Y, loss_function, max_count):
    """Return ``(test_loss, test_acc)`` of ``model`` on the fixed test set."""
    model.eval()
    # No gradients are needed for evaluation; skipping autograd avoids
    # building a graph over the entire test set.
    with torch.no_grad():
        logits = model(test_X).reshape(-1, max_count + 1)
    # Compare on whatever device the targets live on, independent of the
    # device used for training.
    logits = logits.to(test_Y.device)
    targets = test_Y.reshape(-1)
    test_loss = loss_function(logits, targets)
    test_acc = (torch.argmax(logits, dim=-1) == targets).float().mean()
    return test_loss, test_acc


def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count,
                device='cpu'):
    """Train ``model`` with Adam on freshly sampled batches, logging progress.

    Every step draws ``batch_size`` new examples via ``get_single_example``
    and takes one optimizer step on the cross-entropy loss over
    ``max_count + 1`` classes. Roughly every 1% of the steps, and always on
    the last step, the model is evaluated on the fixed set returned by
    ``get_testset``, a summary (plus one verbose example) is printed, and the
    metrics are recorded.

    Args:
        model: ``nn.Module`` producing logits whose last dimension has
            ``max_count + 1`` entries.
        lr: Adam learning rate.
        num_steps: Number of optimizer steps to take.
        batch_size: Number of examples sampled per step.
        n_tokens: Data-generation parameter forwarded to
            ``get_single_example`` / ``get_testset``.
        seqlen: Data-generation parameter forwarded to
            ``get_single_example`` / ``get_testset``.
        max_count: Data-generation parameter forwarded to
            ``get_single_example`` / ``get_testset``. Also fixes the number
            of output classes (``max_count + 1``).
        device: Device for the model, the training batches and the test
            inputs (default ``'cpu'``).

    Returns:
        ``(train_losses, test_losses, accs)``: lists with one entry per
        logging point, each rounded to 2 decimals. ``train_losses`` holds
        the loss of that step's batch only, not a running average. The
        model is left in eval mode.
    """
    # Autograd debugging aid. Called as a plain function it stays enabled
    # after this function returns, and PyTorch documents that it slows
    # execution.
    torch.autograd.set_detect_anomaly(True)
    model.to(device)
    start_time = time()
    accs = []
    train_losses = []
    test_losses = []
    loss_function = nn.CrossEntropyLoss(
        # Optional class weighting (currently disabled):
        # weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)

    test_X, test_Y = get_testset(n_tokens, seqlen, max_count)
    print('test size', test_X.shape)
    # The model lives on `device`, so its test input must too. NOTE(review):
    # assumes get_testset returns torch tensors in the same batch-on-dim-1
    # layout as the training batches -- confirm.
    test_X = test_X.to(device)

    # Log about every 1% of steps; clamp to 1 so runs with fewer than 100
    # steps do not divide by zero in the modulo below.
    log_every = max(1, num_steps // 100)

    for step in range(num_steps):
        batch_X, batch_Y = _make_batch(batch_size, n_tokens, seqlen, max_count, device)
        model.train()
        model.zero_grad()
        logits = model(batch_X)
        loss = loss_function(logits.reshape(-1, max_count + 1), batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()

        if step % log_every != 0 and step != num_steps - 1:
            continue

        # Print a summary of the current state of training.
        test_loss, test_acc = _evaluate(model, test_X, test_Y, loss_function, max_count)
        print('step', step, 'out of', num_steps)
        print('loss train', float(loss))
        print('loss test', float(test_loss))
        print('accuracy test', float(test_acc))
        do_verbose_test(model, n_tokens, seqlen, max_count, device=device)
        print()
        sys.stdout.flush()
        accs.append(round(float(test_acc), 2))
        train_losses.append(round(float(loss.detach()), 2))
        test_losses.append(round(float(test_loss), 2))

    print('\nTRAINING TIME:', time() - start_time)
    model.eval()
    return train_losses, test_losses, accs