From 0fead7ba8062c5704b4a27c9a1c57427b6e8ecea Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Thu, 27 May 2021 21:05:36 +0200 Subject: Allow GPU use --- train/train.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'train') diff --git a/train/train.py b/train/train.py index 1d8f26e..e72fb3f 100644 --- a/train/train.py +++ b/train/train.py @@ -1,4 +1,5 @@ from time import time +import sys import torch from torch import nn @@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d # Printing a summary of the current state of training every 1% of steps. model.eval() predicted_logits = model.forward(test_X).reshape(-1, max_count + 1) + predicted_logits = predicted_logits.to('cpu') test_acc = ( torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1)) / test_Y.reshape(-1).shape[0]) @@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d print('accuracy test', float(test_acc)) do_verbose_test(model, n_tokens, seqlen, max_count) print() + sys.stdout.flush() accs.append(test_acc) print('\nTRAINING TIME:', time()-start_time) model.eval() -- cgit v1.2.3