diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-04-20 21:16:36 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-04-20 21:21:40 +0200 |
commit | 9a69f5006be4febb28a9ae1f7b104177c06c2ed1 (patch) | |
tree | 9f09b209c68512e902d52f482a5cb324b202e8cf /src | |
parent | 73e38329db99b56bd40385e676120f0b163f030d (diff) |
Handle dataset
Diffstat (limited to 'src')
-rw-r--r-- | src/data.py | 40 |
1 files changed, 40 insertions, 0 deletions
"""Dataset access for the shapes images and their labels (src/data.py)."""

import os

import numpy as np
import pandas as pd
import torch


class ShapesData(torch.utils.data.Dataset):
    """Map-style dataset over the images listed in ``labels.csv``.

    The labels file is split by row position: rows ``0 .. TRAIN_SIZE - 1``
    are the train split and the next ``TEST_SIZE`` rows are the test split.

    The labels file must provide a ``name`` column (image file name, relative
    to ``data_dir``) and the label columns ``squares`` through ``left``.
    NOTE(review): the meaning of the individual label columns is not visible
    here -- confirm against the data set description.

    Args:
        train: Truthy selects the train split, falsy the test split.
        data_dir: Directory containing ``labels.csv`` and the image files.
            Defaults to ``'./data'``, the location that was previously
            hard-coded.
    """

    # Split sizes; the test split starts right after the train split.
    TRAIN_SIZE = 9000
    TEST_SIZE = 1000

    def __init__(self, train, data_dir='./data'):
        self.train = train
        self.data_dir = data_dir
        self.labels = pd.read_csv(os.path.join(data_dir, 'labels.csv'))

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.TRAIN_SIZE if self.train else self.TEST_SIZE

    def get_index(self, i):
        """Translate a split-local index into a row of the labels table.

        Args:
            i: Index within the selected split, ``0 <= i < len(self)``.

        Returns:
            The row label in ``self.labels``: ``i`` for the train split,
            ``i + TRAIN_SIZE`` for the test split.

        Raises:
            IndexError: If ``i`` is outside ``0 .. len(self) - 1``.
        """
        # Raise IndexError instead of using ``assert``: asserts are stripped
        # under ``python -O``, and IndexError is what lets plain
        # ``for sample in dataset`` iteration terminate.  Negative values are
        # rejected too; otherwise a negative test index would silently
        # resolve to a row of the train split.
        if not 0 <= i < len(self):
            split = 'Train' if self.train else 'Test'
            raise IndexError(f'{split} dataset index out of bounds')
        return i if self.train else i + self.TRAIN_SIZE

    def __getitem__(self, i):
        """Load the ``i``-th sample of the selected split.

        Args:
            i: Index within the selected split.

        Returns:
            A two-element list ``[image, labels]``: the image tensor rescaled
            to ``[-1, 1]`` and a float32 tensor holding the values of the
            ``squares`` .. ``left`` columns for that image.

        Raises:
            IndexError: If ``i`` is outside the selected split.
        """
        # Imported here because this is the only place that decodes images;
        # the length/index logic stays usable without torchvision installed.
        import torchvision

        row = self.get_index(i)
        filename = os.path.join(self.data_dir, self.labels.loc[row, 'name'])
        image = torchvision.io.read_image(filename)
        # ``.loc`` slices by label, so both end columns are included.
        labels = self.labels.loc[row, 'squares':'left']
        return [
            # Images as read have values in [0, 255]; rescale to [-1, 1].
            (image / 255 - 0.5) * 2,
            torch.tensor(labels.values.astype(np.float32)),
        ]