m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/runner.py')
-rw-r--r--src/runner.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/runner.py b/src/runner.py
index 511f9a5..5548a0c 100644
--- a/src/runner.py
+++ b/src/runner.py
@@ -21,12 +21,14 @@ class Runner:
self.count_correct
) = parse_file(file)
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
self.train_loader, self.test_loader = data.get_loaders(
augmentations,
target_transform,
batch_size)
- self.net = Net(**net_config)
+ self.net = Net(**net_config).to(self.device)
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
def run(self):
@@ -42,6 +44,7 @@ class Runner:
total_loss = 0
number_batches = 0
for batch_idx, (data, target) in enumerate(self.train_loader):
+ data, target = data.to(self.device), target.to(self.device)
number_batches += 1
self.optimizer.zero_grad()
output = self.net(data)
@@ -62,6 +65,7 @@ class Runner:
number_batches = 0
with torch.no_grad():
for data, target in self.test_loader:
+ data, target = data.to(self.device), target.to(self.device)
number_batches += 1
output = self.net(data)
test_loss += self.loss_function(output, target)