m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/data.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py
index ff7aca5..8019b88 100644
--- a/src/data.py
+++ b/src/data.py
@@ -141,3 +141,12 @@ def make_augmented_data(augmentations, after_target_transform=None):
datasets.append(dataset)
return torch.utils.data.ConcatDataset(datasets)
+
+def get_loaders(augmentations, target_transform, train_batch_size):
+ train_data = make_augmented_data(augmentations, after_target_transform=target_transform)
+ train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
+
+ test_data = ShapesData(False, target_transform=target_transform)
+ test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False)
+
+ return train_loader, test_loader