diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-29 14:14:33 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-29 14:14:33 +0200 |
commit | a4e33358691431575d169a6102a945e16d132a44 (patch) | |
tree | e1b8b1f20ea238d8dde11b344457a8837bcb372f /train | |
parent | 0fead7ba8062c5704b4a27c9a1c57427b6e8ecea (diff) |
Save model and metrics
Diffstat (limited to 'train')
-rw-r--r-- | train/train.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/train/train.py b/train/train.py index e72fb3f..be6693c 100644 --- a/train/train.py +++ b/train/train.py @@ -22,6 +22,8 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d start_time = time() accs = [] + train_losses = [] + test_losses = [] loss_function = nn.CrossEntropyLoss( # weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float)) @@ -52,16 +54,21 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d model.eval() predicted_logits = model.forward(test_X).reshape(-1, max_count + 1) predicted_logits = predicted_logits.to('cpu') + test_loss = loss_function(predicted_logits, test_Y.reshape(-1)) test_acc = ( torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1)) / test_Y.reshape(-1).shape[0]) print('step', step, 'out of', num_steps) print('loss train', float(loss)) + print('loss test', float(test_loss)) print('accuracy test', float(test_acc)) do_verbose_test(model, n_tokens, seqlen, max_count) print() sys.stdout.flush() - accs.append(test_acc) + accs.append(round(float(test_acc), 2)) + train_losses.append(round(float(loss.detach()), 2)) + test_losses.append(round(float(test_loss.detach()), 2)) + # print(accs, train_losses, test_losses) print('\nTRAINING TIME:', time()-start_time) model.eval() - return accs + return train_losses, test_losses, accs |