diff options
Diffstat (limited to 'train')
| -rw-r--r-- | train/train.py | 3 | 
1 file changed, 3 insertions, 0 deletions
| diff --git a/train/train.py b/train/train.py index 1d8f26e..e72fb3f 100644 --- a/train/train.py +++ b/train/train.py @@ -1,4 +1,5 @@  from time import time +import sys  import torch  from torch import nn @@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d              # Printing a summary of the current state of training every 1% of steps.              model.eval()              predicted_logits = model.forward(test_X).reshape(-1, max_count + 1) +            predicted_logits = predicted_logits.to('cpu')              test_acc = (                      torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))                      / test_Y.reshape(-1).shape[0]) @@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d              print('accuracy test', float(test_acc))              do_verbose_test(model, n_tokens, seqlen, max_count)              print() +            sys.stdout.flush()              accs.append(test_acc)      print('\nTRAINING TIME:', time()-start_time)      model.eval() |