c6cb6634f2364b880f7f06d1eafa6d759bfd1bc7
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / pixel2pixel / dataset_utils.py
1 import os
2 import numpy as np
3 import torch.utils
4 import cv2
def split2list(images, split):
    """Partition ``images`` into a train list and a test list, preserving order.

    Args:
        images: sized sequence of samples (any objects) to partition.
        split: either
            * a path (str) to a text file with one line per sample; a line whose
              stripped content is exactly ``'1'`` sends the matching sample to the
              train list, any other content sends it to the test list; or
            * a real number (int, float or NumPy scalar): each sample is
              independently put in the train list when a draw from
              ``np.random.uniform(0, 1)`` is below ``split``. The result
              therefore depends on NumPy's global RNG state.

    Returns:
        ``(train_samples, test_samples)`` as two lists.

    Raises:
        ValueError: the split file does not contain exactly one line per sample.
        TypeError: ``split`` is neither a str nor a real number.
    """
    if isinstance(split, str):
        with open(split) as f:
            split_values = [line.strip() == '1' for line in f]
        # Explicit check (not assert): under `python -O` an assert would be
        # stripped and zip() below would silently truncate to the shorter input.
        if len(split_values) != len(images):
            raise ValueError(
                'split file {} has {} entries but there are {} images'.format(
                    split, len(split_values), len(images)))
    elif isinstance(split, (int, float, np.integer, np.floating)) and not isinstance(split, bool):
        split_values = np.random.uniform(0, 1, len(images)) < split
    else:
        # Explicit raise (not `assert False`): with asserts stripped the original
        # fell through to a NameError on `split_values`.
        raise TypeError('split could not be understood')

    train_samples = [sample for sample, sval in zip(images, split_values) if sval]
    test_samples = [sample for sample, sval in zip(images, split_values) if not sval]
    return train_samples, test_samples
def load_flo(path):
    """Read a dense optical-flow field from a Middlebury ``.flo`` file.

    File layout (little-endian):
        float32 magic 202021.25 (the ASCII tag 'PIEH'),
        int32 width, int32 height,
        width * height * 2 float32 values: (u, v) interleaved, row by row.

    Args:
        path: path to the ``.flo`` file.

    Returns:
        float32 ndarray of shape (height, width, 2); ``[..., 0]`` is u and
        ``[..., 1]`` is v.

    Raises:
        ValueError: wrong magic number, malformed header, or truncated flow data.
    """
    with open(path, 'rb') as f:
        # '<f4' / '<i4': the format is little-endian regardless of host byte order.
        magic = np.fromfile(f, '<f4', count=1)
        if magic.size != 1 or magic[0] != 202021.25:
            raise ValueError('Magic number incorrect. Invalid .flo file')
        header = np.fromfile(f, '<i4', count=2)
        if header.size != 2 or header[0] <= 0 or header[1] <= 0:
            raise ValueError('Invalid .flo header (width/height) in {}'.format(path))
        # Width is stored before height.
        width, height = int(header[0]), int(header[1])
        expected = 2 * width * height
        data = np.fromfile(f, '<f4', count=expected)
    if data.size != expected:
        raise ValueError('Truncated .flo file {}: expected {} values, got {}'.format(
            path, expected, data.size))
    # reshape, not np.resize: np.resize would silently repeat values to pad a
    # truncated file instead of failing.
    return data.reshape(height, width, 2).astype(np.float32, copy=False)
def default_loader(root, path_imgs, path_flows, additional_info=False):
    """Load a group of images and their ``.flo`` optical-flow targets.

    Args:
        root: directory joined (``os.path.join``) in front of every path.
        path_imgs: iterable of image paths relative to ``root``.
        path_flows: iterable of ``.flo`` paths relative to ``root``.
        additional_info: when True, also return the resolved file paths. This
            supports the ``loader(root, inputs, targets, additional_info=True)``
            call made by ``ListDatasetWithAdditionalInfo``.

    Returns:
        ``(imgs, flows)``:
            * imgs: float32 arrays from ``cv2.imread`` with the last axis
              reversed (OpenCV loads BGR, so this yields RGB).
            * flows: the arrays returned by ``load_flo``.
        With ``additional_info=True``, returns
        ``(imgs, flows, img_paths, flow_paths)``, where the two extra lists hold
        the root-joined paths that were read.

    Raises:
        OSError: an image could not be read (``cv2.imread`` returned None).
    """
    img_paths = [os.path.join(root, path) for path in path_imgs]
    flow_paths = [os.path.join(root, path_flo) for path_flo in path_flows]

    imgs = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError('Could not read image: {}'.format(img_path))
        imgs.append(img[:, :, ::-1].astype(np.float32))

    flows = [load_flo(flo_path) for flo_path in flow_paths]

    if additional_info:
        return imgs, flows, img_paths, flow_paths
    return imgs, flows
class ListDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``(inputs, targets)`` path groups.

    For each item, the entry of ``path_list`` is unpacked into two parts. Both
    parts are passed to ``loader(root, inputs, targets)``. The loaded pair then
    goes through ``transform(inputs, targets)`` when a transform is set.
    """

    def __init__(self, root, path_list, transform=None, loader=default_loader):
        # Only remember the configuration; all loading is deferred to __getitem__.
        self.root = root
        self.path_list = path_list
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(inputs, targets)`` pair for ``index``."""
        input_refs, target_refs = self.path_list[index]
        loaded_inputs, loaded_targets = self.loader(self.root, input_refs, target_refs)
        if self.transform is None:
            return loaded_inputs, loaded_targets
        transformed_inputs, transformed_targets = self.transform(loaded_inputs, loaded_targets)
        return transformed_inputs, transformed_targets

    def __len__(self):
        """Number of entries in ``path_list``."""
        return len(self.path_list)
class ListDatasetWithAdditionalInfo(ListDataset):
    """``ListDataset`` variant whose samples also carry the loader's file paths.

    Construction (``root, path_list, transform=None, loader=default_loader``)
    and ``len()`` are inherited from ``ListDataset`` instead of being duplicated.
    The loader must accept an ``additional_info=True`` keyword and return
    ``(inputs, targets, input_paths, target_paths)``.
    """

    def __getitem__(self, index):
        """Return ``(inputs, targets, input_paths, target_paths)`` for ``index``.

        ``transform`` (if set) is applied to inputs/targets only; the paths are
        returned exactly as the loader produced them.
        """
        inputs, targets = self.path_list[index]
        inputs, targets, input_paths, target_paths = self.loader(
            self.root, inputs, targets, additional_info=True)
        if self.transform is not None:
            inputs, targets = self.transform(inputs, targets)
        return inputs, targets, input_paths, target_paths