From 9a69f5006be4febb28a9ae1f7b104177c06c2ed1 Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Tue, 20 Apr 2021 21:16:36 +0200 Subject: Handle dataset --- src/data.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/data.py (limited to 'src') diff --git a/src/data.py b/src/data.py new file mode 100644 index 0000000..6785511 --- /dev/null +++ b/src/data.py @@ -0,0 +1,40 @@ +import os + +import numpy as np +import pandas as pd +import torch +import torchvision + +class ShapesData(torch.utils.data.Dataset): + def __init__(self, train): + self.train = train + + labels_file = './data/labels.csv' + self.labels = pd.read_csv(labels_file) + + def __len__(self): + if self.train: + return 9000 + else: + return 1000 + + def get_index(self, i): + if self.train: + assert i < 9000, 'Train dataset index out of bounds' + return i + else: + assert i < 1000, 'Test dataset index out of bounds' + return i + 9000 + + def __getitem__(self, i): + i = self.get_index(i) + filename = os.path.join('./data', self.labels.loc[i, 'name']) + image = torchvision.io.read_image(filename) + labels = self.labels.loc[i, 'squares':'left'] + sample = [ + # images as we read them have values in (0, 255) + # rescale to (-1, 1) + (image / 255 - 0.5) * 2, + torch.tensor(labels.values.astype(np.float32)) + ] + return sample -- cgit v1.2.3