diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/data.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py index ff7aca5..8019b88 100644 --- a/src/data.py +++ b/src/data.py @@ -141,3 +141,12 @@ def make_augmented_data(augmentations, after_target_transform=None): datasets.append(dataset) return torch.utils.data.ConcatDataset(datasets) + +def get_loaders(augmentations, target_transform, train_batch_size): + train_data = make_augmented_data(augmentations, after_target_transform=target_transform) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True) + + test_data = ShapesData(False, target_transform=target_transform) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False) + + return train_loader, test_loader |