m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/runner.py')
-rw-r--r--src/runner.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/src/runner.py b/src/runner.py
new file mode 100644
index 0000000..511f9a5
--- /dev/null
+++ b/src/runner.py
@@ -0,0 +1,77 @@
+import torch
+
+from parse_config import parse_file
+import data
+from net import Net
+
+class Runner:
+ def __init__(self, file, metrics):
+ self.metrics = metrics
+ self.setup(file)
+
+ def setup(self, file):
+ (
+ net_config,
+ lr,
+ self.epochs,
+ batch_size,
+ augmentations,
+ target_transform,
+ self.loss_function,
+ self.count_correct
+ ) = parse_file(file)
+
+ self.train_loader, self.test_loader = data.get_loaders(
+ augmentations,
+ target_transform,
+ batch_size)
+
+ self.net = Net(**net_config)
+ self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
+
+ def run(self):
+ for epoch in range(self.epochs):
+ self.train_step()
+ self.test_step()
+
+ def train_step(self):
+ """
+ Performs one epoch of training.
+ """
+ self.net.train()
+ total_loss = 0
+ number_batches = 0
+ for batch_idx, (data, target) in enumerate(self.train_loader):
+ number_batches += 1
+ self.optimizer.zero_grad()
+ output = self.net(data)
+ loss = self.loss_function(output, target)
+ loss.backward()
+ self.optimizer.step()
+ total_loss += loss.detach()
+ if batch_idx % 10 == 0:
+ print('Training: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+ batch_idx * len(data), len(self.train_loader.dataset),
+ 100. * batch_idx / len(self.train_loader), loss.item()))
+ self.metrics.add_train_loss(total_loss/number_batches)
+
+ def test_step(self):
+ self.net.eval()
+ test_loss = 0
+ correct = 0
+ number_batches = 0
+ with torch.no_grad():
+ for data, target in self.test_loader:
+ number_batches += 1
+ output = self.net(data)
+ test_loss += self.loss_function(output, target)
+ correct += self.count_correct(output, target)
+
+ test_loss /= number_batches
+ accuracy = correct / len(self.test_loader.dataset)
+
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+ test_loss, correct, len(self.test_loader.dataset), 100. * accuracy
+ ))
+
+ self.metrics.add_test_metrics(test_loss, accuracy)