import os

import numpy as np
import pandas as pd
import torch
import torchvision
import torchvision.transforms.functional as F


class ShapesData(torch.utils.data.Dataset):
    def __init__(self, train, transform=None, target_transform=None):
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        labels_file = './data/labels.csv'
        self.labels = pd.read_csv(labels_file)

    def __len__(self):
        if self.train:
            return 9000
        else:
            return 1000

    def get_index(self, i):
        # Train samples occupy rows 0-8999 of labels.csv, test samples rows
        # 9000-9999.
        if self.train:
            assert i < 9000, 'Train dataset index out of bounds'
            return i
        else:
            assert i < 1000, 'Test dataset index out of bounds'
            return i + 9000

    def __getitem__(self, i):
        i = self.get_index(i)
        filename = os.path.join('./data', self.labels.loc[i, 'name'])
        image = torchvision.io.read_image(filename)
        # Images as we read them have four channels, each with values in
        # (0, 255). Take the first channel and rescale to (-1, 1).
        image = (image[0] / 255 - 0.5) * 2
        if self.transform:
            image = self.transform(image)
        labels = self.labels.loc[i, 'squares':'left']
        labels = torch.tensor(labels.values.astype(np.float32))
        if self.target_transform:
            labels = self.target_transform(labels)
        sample = [image, labels]
        return sample


def rotation_transform(image, rotation):
    return F.rotate(image.unsqueeze(0), angle=(rotation * 90))


def rotation_target_transform(labels, rotation):
    # Cyclically shift the four directional labels (indices 2-5) so they
    # still describe the rotated image.
    labels[2:6] = labels[2:6][[*range(rotation, 4), *range(0, rotation)]]
    return labels


def make_rotation(rotation):
    return torchvision.transforms.Lambda(
        lambda image: rotation_transform(image, rotation)
    ), torchvision.transforms.Lambda(
        lambda labels: rotation_target_transform(labels, rotation)
    )


def vertical_flip_target_transform(labels):
    labels[[2, 4]] = labels[[4, 2]]
    return labels


def make_vertical_flip():
    return torchvision.transforms.Lambda(
        F.vflip
    ), torchvision.transforms.Lambda(
        vertical_flip_target_transform
    )


def horizontal_flip_target_transform(labels):
    labels[[3, 5]] = labels[[5, 3]]
    return labels


def make_horizontal_flip():
    return torchvision.transforms.Lambda(
        F.hflip
    ), torchvision.transforms.Lambda(
        horizontal_flip_target_transform
    )


def make_transforms(rotation, flip_vertical, flip_horizontal):
    """
    Returns an image transform and a label transform for rotating and/or
    flipping images.

    - rotation: 0-3, indicating the number of counterclockwise 90° turns
    - flip_vertical, flip_horizontal: booleans
    """
    transforms = []
    target_transforms = []
    if rotation > 0:
        rotation, target_rotation = make_rotation(rotation)
        transforms.append(rotation)
        target_transforms.append(target_rotation)
    if flip_vertical:
        flip, target_flip = make_vertical_flip()
        transforms.append(flip)
        target_transforms.append(target_flip)
    if flip_horizontal:
        flip, target_flip = make_horizontal_flip()
        transforms.append(flip)
        target_transforms.append(target_flip)
    if transforms:
        # Rotation adds an extra (batch) dimension, which the net does not
        # expect; squeeze it back out.
        transforms.append(torchvision.transforms.Lambda(
            lambda image: image.squeeze()
        ))
    return (torchvision.transforms.Compose(transforms),
            torchvision.transforms.Compose(target_transforms))


def make_augmented_data(augmentations, after_target_transform=None):
    """
    Returns concatenated ShapesData train datasets, each modified by an
    augmentation.

    - augmentations: list of triples (rotation, vertical flip, horizontal flip).
    - after_target_transform: an additional transform to apply after the
      augmentation ones.
""" datasets = [] for rotation, vertical_flip, horizontal_flip in augmentations: transform, target_transform = make_transforms( rotation, vertical_flip, horizontal_flip ) if after_target_transform: target_transform = torchvision.transforms.Compose([ target_transform, after_target_transform ]) dataset = ShapesData( True, transform=transform, target_transform=target_transform ) datasets.append(dataset) return torch.utils.data.ConcatDataset(datasets) def get_loaders(augmentations, target_transform, train_batch_size): train_data = make_augmented_data(augmentations, after_target_transform=target_transform) train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True) test_data = ShapesData(False, target_transform=target_transform) test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False) return train_loader, test_loader