m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-29 14:14:33 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-29 14:14:33 +0200
commita4e33358691431575d169a6102a945e16d132a44 (patch)
treee1b8b1f20ea238d8dde11b344457a8837bcb372f /train
parent0fead7ba8062c5704b4a27c9a1c57427b6e8ecea (diff)
Save model and metrics
Diffstat (limited to 'train')
-rw-r--r--train/train.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/train/train.py b/train/train.py
index e72fb3f..be6693c 100644
--- a/train/train.py
+++ b/train/train.py
@@ -22,6 +22,8 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
start_time = time()
accs = []
+ train_losses = []
+ test_losses = []
loss_function = nn.CrossEntropyLoss(
# weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
@@ -52,16 +54,21 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
model.eval()
predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
predicted_logits = predicted_logits.to('cpu')
+ test_loss = loss_function(predicted_logits, test_Y.reshape(-1))
test_acc = (
torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
/ test_Y.reshape(-1).shape[0])
print('step', step, 'out of', num_steps)
print('loss train', float(loss))
+ print('loss test', float(test_loss))
print('accuracy test', float(test_acc))
do_verbose_test(model, n_tokens, seqlen, max_count)
print()
sys.stdout.flush()
- accs.append(test_acc)
+ accs.append(round(float(test_acc), 2))
+ train_losses.append(round(float(loss.detach()), 2))
+ test_losses.append(round(float(test_loss.detach()), 2))
+ # print(accs, train_losses, test_losses)
print('\nTRAINING TIME:', time()-start_time)
model.eval()
- return accs
+ return train_losses, test_losses, accs