diff options
author | Marcin Chrzanowski <mc370754@students.mimuw.edu.pl> | 2021-05-27 21:05:36 +0200 |
---|---|---|
committer | Marcin Chrzanowski <mc370754@students.mimuw.edu.pl> | 2021-05-27 21:05:36 +0200 |
commit | 0fead7ba8062c5704b4a27c9a1c57427b6e8ecea (patch) | |
tree | 495814c1070aba7fdead59b8473b89c92aa92feb /train/train.py | |
parent | 0226b13c96e048282cc1d1868eaeb59fd89877b3 (diff) |
Allow GPU use
Diffstat (limited to 'train/train.py')
-rw-r--r-- | train/train.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/train/train.py b/train/train.py index 1d8f26e..e72fb3f 100644 --- a/train/train.py +++ b/train/train.py @@ -1,4 +1,5 @@ from time import time +import sys import torch from torch import nn @@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d # Printing a summary of the current state of training every 1% of steps. model.eval() predicted_logits = model.forward(test_X).reshape(-1, max_count + 1) + predicted_logits = predicted_logits.to('cpu') test_acc = ( torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1)) / test_Y.reshape(-1).shape[0]) @@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d print('accuracy test', float(test_acc)) do_verbose_test(model, n_tokens, seqlen, max_count) print() + sys.stdout.flush() accs.append(test_acc) print('\nTRAINING TIME:', time()-start_time) model.eval() |