m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-04-29 11:33:08 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-04-29 11:33:08 +0200
commit9e22e1e8a221ce55789438956055bb02b3069162 (patch)
tree2f12435502eac1e0675f889b9ca1318b2f82dbad
parent9a69f5006be4febb28a9ae1f7b104177c06c2ed1 (diff)
Add augmentations
-rw-r--r--src/data.py113
1 files changed, 108 insertions, 5 deletions
diff --git a/src/data.py b/src/data.py
index 6785511..ff7aca5 100644
--- a/src/data.py
+++ b/src/data.py
@@ -4,10 +4,13 @@ import numpy as np
import pandas as pd
import torch
import torchvision
+import torchvision.transforms.functional as F
class ShapesData(torch.utils.data.Dataset):
- def __init__(self, train):
+ def __init__(self, train, transform=None, target_transform=None):
self.train = train
+ self.transform = transform
+ self.target_transform = target_transform
labels_file = './data/labels.csv'
self.labels = pd.read_csv(labels_file)
@@ -29,12 +32,112 @@ class ShapesData(torch.utils.data.Dataset):
def __getitem__(self, i):
i = self.get_index(i)
filename = os.path.join('./data', self.labels.loc[i, 'name'])
+
image = torchvision.io.read_image(filename)
+ # Images as we read them have four channels, each with values in (0,
+ # 255) Take the first channel and rescale to (-1, 1)
+ image = (image[0] / 255 - 0.5) * 2
+ if self.transform:
+ image = self.transform(image)
+
labels = self.labels.loc[i, 'squares':'left']
+ labels = torch.tensor(labels.values.astype(np.float32))
+ if self.target_transform:
+ labels = self.target_transform(labels)
+
sample = [
- # images as we read them have values in (0, 255)
- # rescale to (-1, 1)
- (image / 255 - 0.5) * 2,
- torch.tensor(labels.values.astype(np.float32))
+ image,
+ labels
]
return sample
+
+def rotation_transform(image, rotation):
+ return F.rotate(image.unsqueeze(0), angle=(rotation * 90))
+
+def rotation_target_transform(labels, rotation):
+ labels[2:6] = labels[2:6][[*range(rotation, 4), *range(0, rotation)]]
+ return labels
+
+def make_rotation(rotation):
+ return torchvision.transforms.Lambda(
+ lambda image: rotation_transform(image, rotation)
+ ), torchvision.transforms.Lambda(
+ lambda labels: rotation_target_transform(labels, rotation)
+ )
+
+def vertical_flip_target_transform(labels):
+ labels[[2, 4]] = labels[[4, 2]]
+ return labels
+
+def make_vertical_flip():
+ return torchvision.transforms.Lambda(
+ F.vflip
+ ), torchvision.transforms.Lambda(
+ vertical_flip_target_transform
+ )
+
+def horizontal_flip_target_transform(labels):
+ labels[[3, 5]] = labels[[5, 3]]
+ return labels
+
+def make_horizontal_flip():
+ return torchvision.transforms.Lambda(
+ F.hflip
+ ), torchvision.transforms.Lambda(
+ horizontal_flip_target_transform
+ )
+
+def make_transforms(rotation, flip_vertical, flip_horizontal):
+ """
+ Returns an image and label transform for rotating and/or flipping images.
+ - rotation: 0-3, indicating number of counter clockwise 90° turns
+ - flip_vertical, flip_horizontal: booleans
+ """
+ transforms = []
+ target_transforms = []
+ if rotation > 0:
+ rotation, target_rotation = make_rotation(rotation)
+ transforms.append(rotation)
+ target_transforms.append(target_rotation)
+ if flip_vertical:
+ flip, target_flip = make_vertical_flip()
+ transforms.append(flip)
+ target_transforms.append(target_flip)
+ if flip_horizontal:
+ flip, target_flip = make_horizontal_flip()
+ transforms.append(flip)
+ target_transforms.append(target_flip)
+
+ if transforms:
+ # Transformations add another dimension, which screws up the net
+ transforms.append(torchvision.transforms.Lambda(lambda image:
+ image.squeeze()
+ ))
+
+ return (torchvision.transforms.Compose(transforms),
+ torchvision.transforms.Compose(target_transforms))
+
+def make_augmented_data(augmentations, after_target_transform=None):
+ """
+ Returns concatanated ShapesData train datasets, each modified by an
+ augmentation.
+ - augmentations: list of triples (rotation, vertical flip, horizontal flip).
+ - after_target_transform: an additional transform to apply after the
+ augmentation ones.
+ """
+ datasets = []
+ for rotation, vertical_flip, horizontal_flip in augmentations:
+ transform, target_transform = make_transforms(
+ rotation, vertical_flip, horizontal_flip
+ )
+ if after_target_transform:
+ target_transform = torchvision.transforms.Compose([
+ target_transform,
+ after_target_transform
+ ])
+ dataset = ShapesData(
+ True, transform=transform, target_transform=target_transform
+ )
+ datasets.append(dataset)
+
+ return torch.utils.data.ConcatDataset(datasets)