m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/data.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py
new file mode 100644
index 0000000..6785511
--- /dev/null
+++ b/src/data.py
@@ -0,0 +1,40 @@
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+import torchvision
+
+class ShapesData(torch.utils.data.Dataset):
+ def __init__(self, train):
+ self.train = train
+
+ labels_file = './data/labels.csv'
+ self.labels = pd.read_csv(labels_file)
+
+ def __len__(self):
+ if self.train:
+ return 9000
+ else:
+ return 1000
+
+ def get_index(self, i):
+ if self.train:
+ assert i < 9000, 'Train dataset index out of bounds'
+ return i
+ else:
+ assert i < 1000, 'Test dataset index out of bounds'
+ return i + 9000
+
+ def __getitem__(self, i):
+ i = self.get_index(i)
+ filename = os.path.join('./data', self.labels.loc[i, 'name'])
+ image = torchvision.io.read_image(filename)
+ labels = self.labels.loc[i, 'squares':'left']
+ sample = [
+ # images as we read them have values in (0, 255)
+ # rescale to (-1, 1)
+ (image / 255 - 0.5) * 2,
+ torch.tensor(labels.values.astype(np.float32))
+ ]
+ return sample