import torch

from parse_config import parse_file

import data
from net import Net


class Runner:
    """Train and evaluate a `Net` model according to a parsed config file.

    The config file supplies the network construction kwargs, learning
    rate, epoch count, batch size, data transforms, loss function and the
    correct-prediction counter. Progress is reported via the injected
    `metrics` object and printed to stdout.
    """

    def __init__(self, file, metrics):
        # `metrics` must provide add_train_loss(...) and
        # add_test_metrics(loss, accuracy) — TODO confirm against caller.
        self.metrics = metrics
        self.setup(file)

    def setup(self, file):
        """Parse `file` and build the data loaders, model and optimizer."""
        (
            net_config,
            lr,
            self.epochs,
            batch_size,
            augmentations,
            target_transform,
            self.loss_function,
            self.count_correct,
        ) = parse_file(file)
        self.train_loader, self.test_loader = data.get_loaders(
            augmentations, target_transform, batch_size)
        self.net = Net(**net_config)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

    def run(self):
        """Run the full loop: one training and one test pass per epoch."""
        for epoch in range(self.epochs):
            self.train_step()
            self.test_step()

    def train_step(self):
        """Perform one epoch of training and record the mean batch loss."""
        self.net.train()
        total_loss = 0
        number_batches = 0
        # Loop variable renamed from `data` to `inputs`: the original name
        # shadowed the module-level `import data` inside this scope.
        for batch_idx, (inputs, target) in enumerate(self.train_loader):
            number_batches += 1
            self.optimizer.zero_grad()
            output = self.net(inputs)
            loss = self.loss_function(output, target)
            loss.backward()
            self.optimizer.step()
            # detach() so the running total does not retain the autograd graph
            total_loss += loss.detach()
            if batch_idx % 10 == 0:
                print('Training: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(inputs), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
        self.metrics.add_train_loss(total_loss / number_batches)

    def test_step(self):
        """Evaluate on the test set; record average loss and accuracy."""
        self.net.eval()
        test_loss = 0
        correct = 0
        number_batches = 0
        # no_grad: inference only, skip building the autograd graph.
        with torch.no_grad():
            for inputs, target in self.test_loader:
                number_batches += 1
                output = self.net(inputs)
                test_loss += self.loss_function(output, target)
                correct += self.count_correct(output, target)
        test_loss /= number_batches
        accuracy = correct / len(self.test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset),
            100. * accuracy
        ))
        self.metrics.add_test_metrics(test_loss, accuracy)