From 25c01ba5092994f5156922e4873281b0a2e69b15 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Sat, 1 May 2021 19:13:36 +0200 Subject: Utilize GPU --- src/runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/runner.py b/src/runner.py index 511f9a5..5548a0c 100644 --- a/src/runner.py +++ b/src/runner.py @@ -21,12 +21,14 @@ class Runner: self.count_correct ) = parse_file(file) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.train_loader, self.test_loader = data.get_loaders( augmentations, target_transform, batch_size) - self.net = Net(**net_config) + self.net = Net(**net_config).to(self.device) self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr) def run(self): @@ -42,6 +44,7 @@ class Runner: total_loss = 0 number_batches = 0 for batch_idx, (data, target) in enumerate(self.train_loader): + data, target = data.to(self.device), target.to(self.device) number_batches += 1 self.optimizer.zero_grad() output = self.net(data) @@ -62,6 +65,7 @@ class Runner: number_batches = 0 with torch.no_grad(): for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) number_batches += 1 output = self.net(data) test_loss += self.loss_function(output, target) -- cgit v1.2.3