m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 15:01:02 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-01 15:01:02 +0200
commit1c708e9a449a90c866d7efb44798ba8a74ee4e85 (patch)
tree5289643e5966b366c641a08ef8dc0d02259bf4fb /src
parent3602753148198a1e03cbf681da5fbd1d6a3512e5 (diff)
Add helper to get both loaders
Diffstat (limited to 'src')
-rw-r--r--src/data.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py
index ff7aca5..8019b88 100644
--- a/src/data.py
+++ b/src/data.py
@@ -141,3 +141,12 @@ def make_augmented_data(augmentations, after_target_transform=None):
datasets.append(dataset)
return torch.utils.data.ConcatDataset(datasets)
+
+def get_loaders(augmentations, target_transform, train_batch_size):
+ train_data = make_augmented_data(augmentations, after_target_transform=target_transform)
+ train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
+
+ test_data = ShapesData(False, target_transform=target_transform)
+ test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False)
+
+ return train_loader, test_loader