diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 19:13:36 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 19:13:36 +0200 |
commit | 25c01ba5092994f5156922e4873281b0a2e69b15 (patch) | |
tree | 60fad3e3d91481dc411b0b8ba7004ec8405a3b40 /src | |
parent | c2f5de3119be4a22474a1c4e874b7e635ab4160c (diff) |
Utilize GPU
Diffstat (limited to 'src')
-rw-r--r-- | src/runner.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/src/runner.py b/src/runner.py index 511f9a5..5548a0c 100644 --- a/src/runner.py +++ b/src/runner.py @@ -21,12 +21,14 @@ class Runner: self.count_correct ) = parse_file(file) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.train_loader, self.test_loader = data.get_loaders( augmentations, target_transform, batch_size) - self.net = Net(**net_config) + self.net = Net(**net_config).to(self.device) self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr) def run(self): @@ -42,6 +44,7 @@ class Runner: total_loss = 0 number_batches = 0 for batch_idx, (data, target) in enumerate(self.train_loader): + data, target = data.to(self.device), target.to(self.device) number_batches += 1 self.optimizer.zero_grad() output = self.net(data) @@ -62,6 +65,7 @@ class Runner: number_batches = 0 with torch.no_grad(): for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) number_batches += 1 output = self.net(data) test_loss += self.loss_function(output, target) |