diff options
Diffstat (limited to 'train')
-rw-r--r-- | train/train.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/train/train.py b/train/train.py index 1d8f26e..e72fb3f 100644 --- a/train/train.py +++ b/train/train.py @@ -1,4 +1,5 @@ from time import time +import sys import torch from torch import nn @@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d # Printing a summary of the current state of training every 1% of steps. model.eval() predicted_logits = model.forward(test_X).reshape(-1, max_count + 1) + predicted_logits = predicted_logits.to('cpu') test_acc = ( torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1)) / test_Y.reshape(-1).shape[0]) @@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d print('accuracy test', float(test_acc)) do_verbose_test(model, n_tokens, seqlen, max_count) print() + sys.stdout.flush() accs.append(test_acc) print('\nTRAINING TIME:', time()-start_time) model.eval() |