diff options
Diffstat (limited to 'src/runner.py')
-rw-r--r-- | src/runner.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/src/runner.py b/src/runner.py new file mode 100644 index 0000000..511f9a5 --- /dev/null +++ b/src/runner.py @@ -0,0 +1,77 @@ +import torch + +from parse_config import parse_file +import data +from net import Net + +class Runner: + def __init__(self, file, metrics): + self.metrics = metrics + self.setup(file) + + def setup(self, file): + ( + net_config, + lr, + self.epochs, + batch_size, + augmentations, + target_transform, + self.loss_function, + self.count_correct + ) = parse_file(file) + + self.train_loader, self.test_loader = data.get_loaders( + augmentations, + target_transform, + batch_size) + + self.net = Net(**net_config) + self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr) + + def run(self): + for epoch in range(self.epochs): + self.train_step() + self.test_step() + + def train_step(self): + """ + Performs one epoch of training. + """ + self.net.train() + total_loss = 0 + number_batches = 0 + for batch_idx, (data, target) in enumerate(self.train_loader): + number_batches += 1 + self.optimizer.zero_grad() + output = self.net(data) + loss = self.loss_function(output, target) + loss.backward() + self.optimizer.step() + total_loss += loss.detach() + if batch_idx % 10 == 0: + print('Training: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + batch_idx * len(data), len(self.train_loader.dataset), + 100. * batch_idx / len(self.train_loader), loss.item())) + self.metrics.add_train_loss(total_loss/number_batches) + + def test_step(self): + self.net.eval() + test_loss = 0 + correct = 0 + number_batches = 0 + with torch.no_grad(): + for data, target in self.test_loader: + number_batches += 1 + output = self.net(data) + test_loss += self.loss_function(output, target) + correct += self.count_correct(output, target) + + test_loss /= number_batches + accuracy = correct / len(self.test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(self.test_loader.dataset), 100. * accuracy + )) + + self.metrics.add_test_metrics(test_loss, accuracy) |