m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train/train.py')
-rw-r--r--train/train.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/train/train.py b/train/train.py
index 1d8f26e..e72fb3f 100644
--- a/train/train.py
+++ b/train/train.py
@@ -1,4 +1,5 @@
from time import time
+import sys
import torch
from torch import nn
@@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
# Printing a summary of the current state of training every 1% of steps.
model.eval()
predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
+ predicted_logits = predicted_logits.to('cpu')
test_acc = (
torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
/ test_Y.reshape(-1).shape[0])
@@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
print('accuracy test', float(test_acc))
do_verbose_test(model, n_tokens, seqlen, max_count)
print()
+ sys.stdout.flush()
accs.append(test_acc)
print('\nTRAINING TIME:', time()-start_time)
model.eval()