m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src/data.py
blob: 67855110a6cbddedd5d2d790b4e98ea77d2091c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

import numpy as np
import pandas as pd
import torch
import torchvision

class ShapesData(torch.utils.data.Dataset):
    """Image/label dataset backed by a ``labels.csv`` table and image files.

    The labels table is read once at construction time.  Its first
    ``TRAIN_SIZE`` rows form the train split and the ``TEST_SIZE`` rows
    right after them form the test split; which split an instance exposes
    is selected by the ``train`` flag.
    """

    # Sizes of the two splits.  The test split starts at row TRAIN_SIZE of
    # the labels table.
    TRAIN_SIZE = 9000
    TEST_SIZE = 1000

    def __init__(self, train, data_dir='./data'):
        """Load the labels table.

        Args:
            train: truthy to expose the train split, falsy for the test
                split.
            data_dir: directory holding ``labels.csv`` and the image files
                named in its ``name`` column.
        """
        super().__init__()
        self.train = train
        self.data_dir = data_dir

        labels_file = os.path.join(data_dir, 'labels.csv')
        self.labels = pd.read_csv(labels_file)

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return self.TRAIN_SIZE if self.train else self.TEST_SIZE

    def get_index(self, i):
        """Translate a split-local index into a row label of the labels table.

        Args:
            i: index within the selected split, ``0 <= i < len(self)``.

        Returns:
            ``i`` for the train split, ``i + TRAIN_SIZE`` for the test split.

        Raises:
            IndexError: if ``i`` lies outside the selected split.  An
                ``IndexError`` (rather than an ``assert``) keeps the check
                alive under ``python -O`` and lets plain ``for`` iteration
                over the dataset terminate normally.
        """
        # Negative indices are rejected too: in the test split, `i + offset`
        # would otherwise silently land on a row of the train split.
        if not 0 <= i < len(self):
            if self.train:
                raise IndexError('Train dataset index out of bounds')
            raise IndexError('Test dataset index out of bounds')

        return i if self.train else i + self.TRAIN_SIZE

    def __getitem__(self, i) -> list:
        """Return ``[image, labels]`` for the ``i``-th sample of the split.

        The image is read from the file named in the ``name`` column and
        rescaled; the labels are the columns from ``'squares'`` through
        ``'left'`` (``.loc`` label slices include both ends) as a float32
        tensor.

        Raises:
            IndexError: if ``i`` lies outside the selected split.
        """
        row = self.get_index(i)
        filename = os.path.join(self.data_dir, self.labels.loc[row, 'name'])
        image = torchvision.io.read_image(filename)
        labels = self.labels.loc[row, 'squares':'left']
        return [
            # images as we read them have values in (0, 255)
            # rescale to (-1, 1)
            (image / 255 - 0.5) * 2,
            torch.tensor(labels.values.astype(np.float32)),
        ]