# Source blob: 67855110a6cbddedd5d2d790b4e98ea77d2091c6 (plain)
import os
import numpy as np
import pandas as pd
import torch
import torchvision
class ShapesData(torch.utils.data.Dataset):
    """Image dataset backed by ``./data/labels.csv`` with a fixed split.

    The first ``TRAIN_SIZE`` rows of the labels table form the train split
    and the following ``TEST_SIZE`` rows form the test split.  Each sample
    is a two-element list ``[image, labels]`` (see ``__getitem__``).
    """

    # Number of samples in each split.  Test rows are stored directly after
    # the train rows, so TRAIN_SIZE is also the offset of the test split.
    TRAIN_SIZE = 9000
    TEST_SIZE = 1000

    def __init__(self, train):
        """Read the labels table and remember which split to serve.

        Args:
            train: truthy selects the train split, falsy the test split.
        """
        super().__init__()
        self.train = train
        labels_file = './data/labels.csv'
        self.labels = pd.read_csv(labels_file)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.TRAIN_SIZE if self.train else self.TEST_SIZE

    def get_index(self, i):
        """Translate a split-local index into a row label of ``self.labels``.

        Negative indices count from the end of the split, as with built-in
        sequences, so they can never reach into the other split.

        Args:
            i: position inside the selected split.

        Returns:
            The label to use with ``self.labels.loc``.

        Raises:
            IndexError: if ``i`` falls outside the selected split.  An
                ``IndexError`` (rather than an ``assert``) keeps the check
                active under ``python -O`` and lets plain ``for`` iteration
                over the dataset terminate normally.
        """
        size = len(self)
        if i < 0:
            # Rebind instead of ``+=`` so a caller-owned array/tensor index
            # is never modified in place.
            i = i + size
        if not 0 <= i < size:
            if self.train:
                raise IndexError('Train dataset index out of bounds')
            raise IndexError('Test dataset index out of bounds')
        return i if self.train else i + self.TRAIN_SIZE

    def __getitem__(self, i):
        """Load the ``i``-th sample of the selected split.

        Args:
            i: position inside the selected split.

        Returns:
            ``[image, labels]`` where ``image`` is the file named in the
            ``name`` column, rescaled as described below, and ``labels`` is
            a float32 tensor built from the columns ``squares`` through
            ``left`` (both inclusive, as ``.loc`` slices by label).

        Raises:
            IndexError: if ``i`` falls outside the selected split.
        """
        row = self.get_index(i)
        filename = os.path.join('./data', self.labels.loc[row, 'name'])
        image = torchvision.io.read_image(filename)
        labels = self.labels.loc[row, 'squares':'left']
        return [
            # images as we read them have values in (0, 255)
            # rescale to (-1, 1)
            (image / 255 - 0.5) * 2,
            torch.tensor(labels.values.astype(np.float32)),
        ]
|