diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/data.py | 113 |
1 files changed, 108 insertions, 5 deletions
diff --git a/src/data.py b/src/data.py index 6785511..ff7aca5 100644 --- a/src/data.py +++ b/src/data.py @@ -4,10 +4,13 @@ import numpy as np import pandas as pd import torch import torchvision +import torchvision.transforms.functional as F class ShapesData(torch.utils.data.Dataset): - def __init__(self, train): + def __init__(self, train, transform=None, target_transform=None): self.train = train + self.transform = transform + self.target_transform = target_transform labels_file = './data/labels.csv' self.labels = pd.read_csv(labels_file) @@ -29,12 +32,112 @@ class ShapesData(torch.utils.data.Dataset): def __getitem__(self, i): i = self.get_index(i) filename = os.path.join('./data', self.labels.loc[i, 'name']) + image = torchvision.io.read_image(filename) + # Images as we read them have four channels, each with values in (0, + # 255) Take the first channel and rescale to (-1, 1) + image = (image[0] / 255 - 0.5) * 2 + if self.transform: + image = self.transform(image) + labels = self.labels.loc[i, 'squares':'left'] + labels = torch.tensor(labels.values.astype(np.float32)) + if self.target_transform: + labels = self.target_transform(labels) + sample = [ - # images as we read them have values in (0, 255) - # rescale to (-1, 1) - (image / 255 - 0.5) * 2, - torch.tensor(labels.values.astype(np.float32)) + image, + labels ] return sample + +def rotation_transform(image, rotation): + return F.rotate(image.unsqueeze(0), angle=(rotation * 90)) + +def rotation_target_transform(labels, rotation): + labels[2:6] = labels[2:6][[*range(rotation, 4), *range(0, rotation)]] + return labels + +def make_rotation(rotation): + return torchvision.transforms.Lambda( + lambda image: rotation_transform(image, rotation) + ), torchvision.transforms.Lambda( + lambda labels: rotation_target_transform(labels, rotation) + ) + +def vertical_flip_target_transform(labels): + labels[[2, 4]] = labels[[4, 2]] + return labels + +def make_vertical_flip(): + return torchvision.transforms.Lambda( + F.vflip + ), torchvision.transforms.Lambda( + vertical_flip_target_transform + ) + +def horizontal_flip_target_transform(labels): + labels[[3, 5]] = labels[[5, 3]] + return labels + +def make_horizontal_flip(): + return torchvision.transforms.Lambda( + F.hflip + ), torchvision.transforms.Lambda( + horizontal_flip_target_transform + ) + +def make_transforms(rotation, flip_vertical, flip_horizontal): + """ + Returns an image and label transform for rotating and/or flipping images. + - rotation: 0-3, indicating number of counter clockwise 90° turns + - flip_vertical, flip_horizontal: booleans + """ + transforms = [] + target_transforms = [] + if rotation > 0: + rotation, target_rotation = make_rotation(rotation) + transforms.append(rotation) + target_transforms.append(target_rotation) + if flip_vertical: + flip, target_flip = make_vertical_flip() + transforms.append(flip) + target_transforms.append(target_flip) + if flip_horizontal: + flip, target_flip = make_horizontal_flip() + transforms.append(flip) + target_transforms.append(target_flip) + + if transforms: + # Transformations add another dimension, which screws up the net + transforms.append(torchvision.transforms.Lambda(lambda image: + image.squeeze() + )) + + return (torchvision.transforms.Compose(transforms), + torchvision.transforms.Compose(target_transforms)) + +def make_augmented_data(augmentations, after_target_transform=None): + """ + Returns concatanated ShapesData train datasets, each modified by an + augmentation. + - augmentations: list of triples (rotation, vertical flip, horizontal flip). + - after_target_transform: an additional transform to apply after the + augmentation ones. + """ + datasets = [] + for rotation, vertical_flip, horizontal_flip in augmentations: + transform, target_transform = make_transforms( + rotation, vertical_flip, horizontal_flip + ) + if after_target_transform: + target_transform = torchvision.transforms.Compose([ + target_transform, + after_target_transform + ]) + dataset = ShapesData( + True, transform=transform, target_transform=target_transform + ) + datasets.append(dataset) + + return torch.utils.data.ConcatDataset(datasets) |