m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--experiment/experiment.py18
-rw-r--r--train/train.py11
2 files changed, 26 insertions, 3 deletions
diff --git a/experiment/experiment.py b/experiment/experiment.py
index 600732f..b06408f 100644
--- a/experiment/experiment.py
+++ b/experiment/experiment.py
@@ -2,6 +2,9 @@ import subprocess
import os
import time
+import pandas as pd
+import torch
+
from train.train import train_model
from util.parse_config import parse_file
from model.encoder import EncoderModel
@@ -16,11 +19,24 @@ class Experiment:
def run(self):
model_config, train_config = parse_file(self.file)
model = EncoderModel(device=self.device, **model_config).to(self.device)
- train_model(model, device=self.device, **train_config)
+ train_losses, test_losses, accuracies = train_model(model, device=self.device, **train_config)
+ self.save_model(model)
+ self.save_metrics(train_losses, test_losses, accuracies)
+
+ def save_model(self, model):
+ torch.save(model.state_dict(), self.dir_path('net.pt'))
def dir_path(self, file):
return '{}/{}'.format(self.dirname, file)
+ def save_metrics(self, train_losses, test_losses, accuracies):
+ data_frame = pd.DataFrame({
+ 'train_loss': train_losses,
+ 'test_loss': test_losses,
+ 'accuracy': accuracies
+ })
+ data_frame.to_csv(self.dir_path('metrics.csv'))
+
def make_dir(self, prefix):
time_string = time.strftime('%Y%m%d%H%M%S')
prefix = '' if prefix == '' else '{}-'.format(prefix)
diff --git a/train/train.py b/train/train.py
index e72fb3f..be6693c 100644
--- a/train/train.py
+++ b/train/train.py
@@ -22,6 +22,8 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
start_time = time()
accs = []
+ train_losses = []
+ test_losses = []
loss_function = nn.CrossEntropyLoss(
# weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
@@ -52,16 +54,21 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
model.eval()
predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
predicted_logits = predicted_logits.to('cpu')
+ test_loss = loss_function(predicted_logits, test_Y.reshape(-1))
test_acc = (
torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
/ test_Y.reshape(-1).shape[0])
print('step', step, 'out of', num_steps)
print('loss train', float(loss))
+ print('loss test', float(test_loss))
print('accuracy test', float(test_acc))
do_verbose_test(model, n_tokens, seqlen, max_count)
print()
sys.stdout.flush()
- accs.append(test_acc)
+ accs.append(round(float(test_acc), 2))
+ train_losses.append(round(float(loss.detach()), 2))
+ test_losses.append(round(float(test_loss.detach()), 2))
+ # print(accs, train_losses, test_losses)
print('\nTRAINING TIME:', time()-start_time)
model.eval()
- return accs
+ return train_losses, test_losses, accs