m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 19:13:36 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 19:13:36 +0200
commit25c01ba5092994f5156922e4873281b0a2e69b15 (patch)
tree60fad3e3d91481dc411b0b8ba7004ec8405a3b40
parentc2f5de3119be4a22474a1c4e874b7e635ab4160c (diff)
Utilize GPU
-rw-r--r--src/runner.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/runner.py b/src/runner.py
index 511f9a5..5548a0c 100644
--- a/src/runner.py
+++ b/src/runner.py
@@ -21,12 +21,14 @@ class Runner:
self.count_correct
) = parse_file(file)
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
self.train_loader, self.test_loader = data.get_loaders(
augmentations,
target_transform,
batch_size)
- self.net = Net(**net_config)
+ self.net = Net(**net_config).to(self.device)
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
def run(self):
@@ -42,6 +44,7 @@ class Runner:
total_loss = 0
number_batches = 0
for batch_idx, (data, target) in enumerate(self.train_loader):
+ data, target = data.to(self.device), target.to(self.device)
number_batches += 1
self.optimizer.zero_grad()
output = self.net(data)
@@ -62,6 +65,7 @@ class Runner:
number_batches = 0
with torch.no_grad():
for data, target in self.test_loader:
+ data, target = data.to(self.device), target.to(self.device)
number_batches += 1
output = self.net(data)
test_loss += self.loss_function(output, target)