30e8815907b7b2ac30675dd4180a37e09e632878
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / pixel2pixel / mpisintel.py
1 import os.path
2 import glob
3 from .dataset_utils import split2list, ListDataset
# Requirements: numpy and cv2 (OpenCV)
6 import numpy as np
7 import cv2
# Public factory functions exported by this module.
__all__ = ['mpi_sintel_clean','mpi_sintel_final','mpi_sintel_both', 'mpi_sintel_depth', 'mpi_sintel_sceneflow']

# Check for endianness, based on Daniel Scharstein's optical flow code.
# Using little-endian architecture, these two should be equal.
TAG_FLOAT = 202021.25
# NOTE(review): TAG_CHAR is not referenced anywhere in this file; presumably
# kept for parity with the reference flow/depth I/O code.
TAG_CHAR = 'PIEH'
# Ceiling applied to depth values read by depth_read (np.minimum clamp).
MAX_DEPTH = 150
17 '''
18 Dataset routines for MPI Sintel.
19 http://sintel.is.tue.mpg.de/
20 clean version imgs are without shaders, final version imgs are fully rendered
21 The dataset is not very big, you might want to only pretrain on it for flownet
22 '''
23 ############################################################################
def mpi_sintel_clean(dataset_config, root, split=None, transforms=None):
    """Build (train, test) ListDatasets over the 'clean' pass of MPI Sintel flow.

    `transforms` is indexed: transforms[0] for train, transforms[1] for test.
    """
    train_samples, test_samples = make_dataset_flow(root, split, 'clean')
    train_ds = ListDataset(root, train_samples, transforms[0])
    test_ds = ListDataset(root, test_samples, transforms[1])
    return train_ds, test_ds
def mpi_sintel_final(dataset_config, root, split=None, transforms=None):
    """Build (train, test) ListDatasets over the fully rendered 'final' pass."""
    sample_lists = make_dataset_flow(root, split, 'final')
    return (ListDataset(root, sample_lists[0], transforms[0]),
            ListDataset(root, sample_lists[1], transforms[1]))
def mpi_sintel_both(dataset_config, root, split=None, transforms=None):
    """Build datasets drawing frames from both the 'clean' and 'final' passes.

    The two passes are concatenated in their original order rather than
    shuffled: a clean frame and its final rendering differ only slightly,
    so shuffling could leak near-duplicates across the train/test split.
    """
    clean_train, clean_test = make_dataset_flow(root, split, 'clean')
    final_train, final_test = make_dataset_flow(root, split, 'final')
    train_ds = ListDataset(root, clean_train + final_train, transforms[0])
    test_ds = ListDataset(root, clean_test + final_test, transforms[1])
    return train_ds, test_ds
49 ############################################################################
def mpi_sintel_depth(dataset_config, root, split=None, transforms=None):
    """Image-pair input, first-frame depth target: (train, test) datasets."""
    def _build(samples, transform):
        # every dataset uses the single-depth loader
        return ListDataset(root, samples, transform, loader=mpi_sintel_depth_loader1)
    train_samples, test_samples = make_dataset_depth(root, split, 'depth', num_target=1)
    return _build(train_samples, transforms[0]), _build(test_samples, transforms[1])
def mpi_sintel_sceneflow(dataset_config, root, split=None, transforms=None):
    """Image-pair input, (flow, depth) scene-flow targets: (train, test) datasets."""
    def _build(samples, transform):
        return ListDataset(root, samples, transform, loader=mpi_sintel_sceneflow_loader)
    train_samples, test_samples = make_dataset_depth(root, split, 'sceneflow', num_target=3)
    return _build(train_samples, transforms[0]), _build(test_samples, transforms[1])
64 ############################################################################
65 # internal functions
66 ############################################################################
def load_flo(path):
    """Read a Middlebury .flo optical-flow file into a (height, width, 2) float32 array.

    File layout (little-endian): float32 magic tag 202021.25, int32 width,
    int32 height, then width*height (u, v) float32 pairs in row-major order.

    :param path: path to the .flo file
    :return: numpy float32 array of shape (height, width, 2)
    :raises AssertionError: if the magic tag is wrong
    :raises ValueError: if the file is truncated (reshape size mismatch)
    """
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        assert (202021.25 == magic), 'Magic number incorrect. Invalid .flo file'
        # The .flo header stores width first, then height. (The previous code
        # read these into swapped variable names but reshaped consistently,
        # so behavior is unchanged — only the naming is corrected here.)
        width = np.fromfile(f, np.int32, count=1)[0]
        height = np.fromfile(f, np.int32, count=1)[0]
        data = np.fromfile(f, np.float32, count=2 * width * height)
    # BUGFIX: use reshape instead of np.resize — np.resize silently repeats
    # data when the file is truncated; reshape fails loudly instead.
    return data.reshape((height, width, 2))
def depth_read(filename):
    """Read depth data from a Sintel .dpt file, return as a numpy array.

    File layout: float32 tag (TAG_FLOAT), int32 width, int32 height, then
    width*height float32 depth values in row-major order.

    :param filename: path to the .dpt file
    :return: float32 array of shape (height, width), clamped to MAX_DEPTH
    :raises AssertionError: on a wrong tag or implausible dimensions
    """
    # BUGFIX: the file handle was previously opened with open() and never
    # closed (a leak, especially when an assert fired); 'with' guarantees
    # closure on every path.
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT, check)
        width = np.fromfile(f, dtype=np.int32, count=1)[0]
        height = np.fromfile(f, dtype=np.int32, count=1)[0]
        size = width * height
        assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width, height)
        depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
    # Clamp far-plane values; Sintel encodes sky/background as huge depths.
    depth = np.minimum(depth, MAX_DEPTH)
    return depth
def make_dataset_flow(dir, split, dataset_type='clean'):
    """Collect [[img1, img2], [flow_map]] samples for Sintel optical flow.

    Scans <dir>/flow/<scene>/frame_NNNN.flo and pairs each flow map with
    frame_NNNN.png and frame_NNNN+1.png from the chosen render pass
    (`dataset_type`: 'clean' or 'final'). All stored paths are relative
    to `dir`. Returns (train_list, test_list) via split2list.
    """
    flow_dir = 'flow'
    img_dir = dataset_type
    images = []
    for flow_map in glob.iglob(os.path.join(dir, flow_dir, '*', '*.flo')):
        flow_map = os.path.relpath(flow_map, os.path.join(dir, flow_dir))
        root_filename = flow_map[:-8]        # strip the trailing 'NNNN.flo'
        frame_nb = int(flow_map[-8:-4])
        img1 = os.path.join(img_dir, root_filename + str(frame_nb).zfill(4) + '.png')
        img2 = os.path.join(img_dir, root_filename + str(frame_nb + 1).zfill(4) + '.png')
        flow_map = os.path.join(flow_dir, flow_map)
        # BUGFIX: both frames must exist for the sample to be usable. The
        # previous 'or' only skipped when BOTH were missing, so a pair with
        # one missing image was kept and would fail later at load time.
        if not (os.path.isfile(os.path.join(dir, img1)) and os.path.isfile(os.path.join(dir, img2))):
            continue
        images.append([[img1, img2], [flow_map]])
    return split2list(images, split)
115 ######################################################################################
def make_dataset_depth(dir, split, dataset_type='depth', num_target=None):
    """Collect [[img1, img2], [depth, ...]] samples for Sintel depth/sceneflow.

    Scans <dir>/depth/<scene>/*.dpt, skipping each scene's last frame so
    that the next frame always exists. Images come from the 'final' render
    pass. With num_target == 2 the second frame's depth map is recorded as
    an additional target. All stored paths are relative to `dir`.
    Returns (train_list, test_list) via split2list (default split 0.87).

    NOTE(review): mpi_sintel_sceneflow calls this with num_target=3, which
    falls into the single-target branch below, yet its loader indexes
    path_target[1] — verify the intended target layout for sceneflow.
    """
    depth_dir = 'depth'
    img_dir = 'final'
    images = []
    # BUGFIX: each `folder` returned by glob is already a full path; the old
    # code re-joined it as os.path.join(dir, depth_dir, folder), which only
    # worked by accident when `dir` was absolute (os.path.join discards the
    # prefix before an absolute component) and produced a wrong, duplicated
    # path for a relative `dir`.
    for folder in sorted(glob.glob(os.path.join(dir, depth_dir, '*'))):
        depth_list = sorted(glob.glob(os.path.join(folder, '*.dpt')))
        for depth_map1 in depth_list[:-1]:
            depth_map1 = os.path.relpath(depth_map1, os.path.join(dir, depth_dir))
            root_filename = depth_map1[:-8]      # strip the trailing 'NNNN.dpt'
            frame_nb = int(depth_map1[-8:-4])
            img1 = os.path.join(img_dir, root_filename + str(frame_nb).zfill(4) + '.png')
            img2 = os.path.join(img_dir, root_filename + str(frame_nb + 1).zfill(4) + '.png')
            # BUGFIX: require both frames ('and'); the old 'or' kept pairs
            # with one missing image, which would fail at load time.
            if not (os.path.isfile(os.path.join(dir, img1)) and os.path.isfile(os.path.join(dir, img2))):
                continue
            depth_map1 = os.path.join(depth_dir, depth_map1)
            if num_target == 2:
                depth_map2 = os.path.join(depth_dir, root_filename + str(frame_nb + 1).zfill(4) + '.dpt')
                images.append([[img1, img2], [depth_map1, depth_map2]])
            else:
                images.append([[img1, img2], [depth_map1]])
    return split2list(images, split, default_split=0.87)
# target depth for only the first image
def mpi_sintel_depth_loader1(root, path_imgs, path_depths):
    """Load the image pair as RGB float32 arrays plus the first frame's depth.

    Returns (imgs, depth) where imgs is a list of two float32 arrays and
    depth is a single-element list holding the first frame's depth map.
    """
    depth_target = [depth_read(os.path.join(root, path_depths[0]))]
    images = []
    for rel_path in path_imgs:
        bgr = cv2.imread(os.path.join(root, rel_path))
        # cv2 loads BGR; reverse the channel axis to get RGB
        images.append(bgr[:, :, ::-1].astype(np.float32))
    return images, depth_target
# target sceneflow (depth is for only the first image)
def mpi_sintel_sceneflow_loader(root, path_imgs, path_target):
    """Load the image pair plus (flow, depth) targets for the first frame.

    path_target[0] is a .flo flow map, path_target[1] a .dpt depth map;
    the depth gets a trailing singleton channel axis.
    """
    flow_target = load_flo(os.path.join(root, path_target[0]))
    depth_target = depth_read(os.path.join(root, path_target[1]))[..., np.newaxis]
    # cv2 loads BGR; reverse the channel axis to get RGB float32
    images = [cv2.imread(os.path.join(root, p))[:, :, ::-1].astype(np.float32)
              for p in path_imgs]
    return images, (flow_target, depth_target)