m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train
diff options
context:
space:
mode:
Diffstat (limited to 'train')
-rw-r--r--train/__init__.py0
-rw-r--r--train/__main__.py5
-rw-r--r--train/train.py51
3 files changed, 56 insertions, 0 deletions
diff --git a/train/__init__.py b/train/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/train/__init__.py
diff --git a/train/__main__.py b/train/__main__.py
new file mode 100644
index 0000000..ec544f7
--- /dev/null
+++ b/train/__main__.py
@@ -0,0 +1,5 @@
+from train.train import train_model
+from model.encoder import EncoderModel
+
+model = EncoderModel(16, 64, 128, 10, 4, 4)
+train_model(model, 0.1, 100, 10)
diff --git a/train/train.py b/train/train.py
new file mode 100644
index 0000000..a88129a
--- /dev/null
+++ b/train/train.py
@@ -0,0 +1,51 @@
+from time import time
+
+import torch
+from torch import nn
+from torch import optim
+
+from data.generate import get_single_example
+from data.testset import get_testset
+
+def train_model(model, lr, num_steps, batch_size, device='cpu'):
+ model.to(device)
+
+ start_time = time()
+ accs = []
+
+ loss_function = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+
+ test_X, test_Y = get_testset()
+
+ for step in range(num_steps):
+ batch_examples = [get_single_example() for i in range(batch_size)]
+
+ batch_X = torch.tensor([x[0] for x in batch_examples],
+ device=device
+ ).transpose(0, 1)
+ batch_Y = torch.tensor([x[1] for x in batch_examples],
+ device=device).transpose(0, 1)
+
+ model.train()
+ model.zero_grad()
+ logits = model(batch_X)
+ loss = loss_function(logits.reshape(-1, 10), batch_Y.reshape(-1))
+ loss.backward()
+ optimizer.step()
+
+ if step % (num_steps//100) == 0 or step == num_steps - 1:
+ # Printing a summary of the current state of training every 1% of steps.
+ model.eval()
+ predicted_logits = model.forward(test_X).reshape(-1, 10)
+ test_acc = (
+ torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
+ / test_Y.reshape(-1).shape[0])
+ print('step', step, 'out of', num_steps)
+ print('loss train', float(loss))
+ print('accuracy test', float(test_acc))
+ print()
+ accs.append(test_acc)
+ print('\nTRAINING TIME:', time()-start_time)
+ model.eval()
+ return accs