diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:01:02 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-01 15:01:02 +0200 |
commit | 1c708e9a449a90c866d7efb44798ba8a74ee4e85 (patch) | |
tree | 5289643e5966b366c641a08ef8dc0d02259bf4fb /src | |
parent | 3602753148198a1e03cbf681da5fbd1d6a3512e5 (diff) |
Add helper to get both loaders
Diffstat (limited to 'src')
-rw-r--r-- | src/data.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py index ff7aca5..8019b88 100644 --- a/src/data.py +++ b/src/data.py @@ -141,3 +141,12 @@ def make_augmented_data(augmentations, after_target_transform=None): datasets.append(dataset) return torch.utils.data.ConcatDataset(datasets) + +def get_loaders(augmentations, target_transform, train_batch_size): + train_data = make_augmented_data(augmentations, after_target_transform=target_transform) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True) + + test_data = ShapesData(False, target_transform=target_transform) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False) + + return train_loader, test_loader |