diff options
-rw-r--r-- | train/__init__.py | 0 | ||||
-rw-r--r-- | train/__main__.py | 5 | ||||
-rw-r--r-- | train/train.py | 51 |
3 files changed, 56 insertions, 0 deletions
# ---------------------------------------------------------------------------
# train/__init__.py  (new file, intentionally empty)
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# train/__main__.py
# ---------------------------------------------------------------------------
from train.train import train_model
from model.encoder import EncoderModel

# Guarded so that importing this module does not start a training run;
# `python -m train` still executes it exactly as before.
if __name__ == '__main__':
    model = EncoderModel(16, 64, 128, 10, 4, 4)
    train_model(model, lr=0.1, num_steps=100, batch_size=10)

# ---------------------------------------------------------------------------
# train/train.py
# ---------------------------------------------------------------------------
from time import time

import torch
from torch import nn
from torch import optim

from data.generate import get_single_example
from data.testset import get_testset


def _test_accuracy(model, test_X, test_Y, num_classes):
    """Return the fraction of correctly predicted test targets as a 0-d tensor."""
    # No gradients are needed for evaluation; skipping them avoids building
    # an autograd graph for the whole test set on every report.
    with torch.no_grad():
        predicted_logits = model(test_X).reshape(-1, num_classes)
        targets = test_Y.reshape(-1)
        num_correct = torch.sum(
            torch.argmax(predicted_logits, dim=-1) == targets)
        return num_correct / targets.shape[0]


def train_model(model, lr, num_steps, batch_size, device='cpu',
                num_classes=10):
    """Train *model* with Adam and cross-entropy on freshly generated batches.

    Each step draws ``batch_size`` examples from ``get_single_example()``,
    runs one optimisation step, and roughly every 1% of the steps (and on the
    final step) prints the training loss and the accuracy on the set returned
    by ``get_testset()``.

    Args:
        model: The ``torch.nn.Module`` to train; it is moved to ``device``.
        lr: Learning rate passed to ``optim.Adam``.
        num_steps: Number of optimisation steps to run.
        batch_size: Number of generated examples per step.
        device: Device on which the model and all tensors are placed.
        num_classes: Size of the class dimension the logits are flattened to
            (defaults to 10, the value that was previously hard-coded).

    Returns:
        A list with one 0-d accuracy tensor per progress report. The model is
        left in eval mode.
    """
    model.to(device)

    start_time = time()
    accs = []

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The test set must live on the same device as the model, otherwise
    # evaluation fails for any device other than the one it was created on.
    test_X, test_Y = get_testset()
    test_X = torch.as_tensor(test_X, device=device)
    test_Y = torch.as_tensor(test_Y, device=device)

    # Report about every 1% of the steps. Clamped to 1 so that runs with
    # fewer than 100 steps do not divide by zero in the modulo below.
    report_every = max(1, num_steps // 100)

    for step in range(num_steps):
        batch_examples = [get_single_example() for _ in range(batch_size)]

        # Stacking the examples puts the batch on axis 0; the transpose swaps
        # the first two axes so the batch index ends up on axis 1.
        batch_X = torch.tensor([example[0] for example in batch_examples],
                               device=device).transpose(0, 1)
        batch_Y = torch.tensor([example[1] for example in batch_examples],
                               device=device).transpose(0, 1)

        model.train()
        optimizer.zero_grad()
        logits = model(batch_X)
        # CrossEntropyLoss takes (N, C) logits and (N,) class indices, so all
        # leading axes are flattened into one.
        loss = loss_function(logits.reshape(-1, num_classes),
                             batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()

        if step % report_every == 0 or step == num_steps - 1:
            # Printing a summary of the current state of training.
            model.eval()
            test_acc = _test_accuracy(model, test_X, test_Y, num_classes)
            print('step', step, 'out of', num_steps)
            print('loss train', float(loss))
            print('accuracy test', float(test_acc))
            print()
            accs.append(test_acc)

    print('\nTRAINING TIME:', time() - start_time)
    model.eval()
    return accs