9b46e5c8b89549688358027b4dfc1b300208a032
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / transforms / image_transforms.py
1 from __future__ import division
2 import random
3 import numbers
4 import math
5 import cv2
6 import numpy as np
7 import PIL
9 from .image_transform_utils import *
12 class CheckImages(object):
13     def __call__(self, images, targets):
14         assert isinstance(images, (list, tuple)), 'Input must a list'
15         assert isinstance(targets, (list, tuple)), 'Target must a list'
16         #assert images[0].shape[:2] == targets[0].shape[:2], 'Image and target sizes must match.'
18         for img_idx in range(len(images)):
19             assert images[img_idx].shape[:2] == images[0].shape[:2], 'Image sizes must match. Either provide same size images or use AlignImages() instead of CheckImages()'
20             images[img_idx] = images[img_idx][...,np.newaxis] if (images[img_idx].shape) == 2 else images[img_idx]
22         for img_idx in range(len(targets)):
23             assert targets[img_idx].shape[:2] == targets[0].shape[:2], 'Target sizes must match. Either provide same size targets or use AlignImages() instead of CheckImages()'
24             targets[img_idx] = targets[img_idx][...,np.newaxis] if (targets[img_idx].shape) == 2 else targets[img_idx]
26         return images, targets
29 class AlignImages(object):
30     """Resize everything to the first image's size, before transformations begin.
31        Also make sure the images are in the desired format."""
32     def __init__(self, is_flow=None):
33         self.is_flow = is_flow
35     def __call__(self, images, targets):
36         images = images if isinstance(images, (list,tuple)) else [images]
37         images = [np.array(img) if isinstance(img,PIL.Image.Image) else img for img in images]
39         targets = targets if isinstance(targets, (list,tuple)) else [targets]
40         targets = [np.array(tgt) if isinstance(tgt,PIL.Image.Image) else tgt for tgt in targets]
42         img_size = images[0].shape[:2]
43         images, targets = Scale(img_size, is_flow=self.is_flow)(images, targets)
44         CheckImages()(images, targets)
45         return images, targets
48 class ConvertToTensor(object):
49     """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
50     def __call__(self, images, targets):
51         def func(img, img_idx):
52             img = ImageTransformUtils.array_to_tensor(img)
53             return img
55         images = ImageTransformUtils.apply_to_list(func, images)
56         targets = ImageTransformUtils.apply_to_list(func, targets)
57         return images, targets
60 class CenterCrop(object):
61     """Crops the given inputs and target arrays at the center to have a region of
62     the given size. size can be a tuple (target_height, target_width)
63     or an integer, in which case the target will be of a square shape (size, size)
64     Careful, img1 and img2 may not be the same size"""
65     def __init__(self, size):
66         self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
68     def __call__(self, images, targets):
69         def func(img, tgt, img_idx):
70             th, tw = self.size
71             h1, w1, _ = img.shape
72             x1 = int(round((w1 - tw) / 2.))
73             y1 = int(round((h1 - th) / 2.))
74             img = img[y1: y1 + th, x1: x1 + tw]
75             tgt = tgt[y1: y1 + th, x1: x1 + tw]
76             return img, tgt
78         images = ImageTransformUtils.apply_to_list(func, images)
79         targets = ImageTransformUtils.apply_to_list(func, targets)
81         return images, targets
84 class ScaleMinSide(object):
85     """ Rescales the inputs and target arrays to the given 'size'.
86     After scaling, 'size' will be the size of the smaller edge.
87     For example, if height > width, then image will be rescaled to (size * height / width, size)"""
88     def __init__(self, size, is_flow=None):
89         self.size = size
90         self.is_flow = is_flow
92     def __call__(self, images, targets):
93         def func(img, img_idx, interpolation, is_flow):
94             h, w, _ = img.shape
95             if (w <= h and w == self.size) or (h <= w and h == self.size):
96                 ratio = 1.0
97                 size_out = (h, w)
98             else:
99                 if w < h:
100                     ratio = self.size / w
101                     size_out = (int(round(ratio * h)), self.size)
102                 else:
103                     ratio = self.size / h
104                     size_out = (self.size, int(round(ratio * w)))
105             #
107             img = ImageTransformUtils.resize_img(img, size_out, interpolation=interpolation, is_flow=is_flow)
108             return img
110         def func_img(img, img_idx):
111             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
112             return func(img, img_idx, interpolation=-1, is_flow=is_flow_img)
114         def func_tgt(img, img_idx):
115             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
116             return func(img, img_idx, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
118         images = ImageTransformUtils.apply_to_list(func_img, images)
119         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
121         return images, targets
124 class RandomCrop(object):
125     """Crops the given images"""
126     def __init__(self, size):
127         if isinstance(size, numbers.Number):
128             self.size = (int(size), int(size))
129         else:
130             self.size = size
132     def __call__(self, images, targets):
133         size_h, size_w, _ = images[0].shape
134         th, tw = self.size
135         x1 = np.random.randint(0, size_w - tw) if (size_w>tw) else 0
136         y1 = np.random.randint(0, size_h - th) if (size_h>th) else 0
138         def func(img, img_idx):
139             return img[y1:y1+th, x1:x1+tw]
141         images = ImageTransformUtils.apply_to_list(func, images)
142         targets = ImageTransformUtils.apply_to_list(func, targets)
144         return images, targets
147 class RandomHorizontalFlip(object):
148     """Randomly horizontally flips the given images"""
149     def __init__(self, is_flow=None):
150         self.is_flow = is_flow
152     def __call__(self, images, targets):
153         def func(img, img_idx, is_flow):
154             img = np.copy(np.fliplr(img))
155             if is_flow:
156                 img = ImageTransformUtils.scale_flow(img, (-1), 1)
157             return img
159         def func_img(img, img_idx):
160             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
161             img = func(img, img_idx, is_flow_img)
162             return img
164         def func_tgt(img, img_idx):
165             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
166             img = func(img, img_idx, is_flow_tgt)
167             return img
169         if np.random.random() < 0.5:
170             images = ImageTransformUtils.apply_to_list(func_img, images)
171             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
173         return images, targets
176 class RandomVerticalFlip(object):
177     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
178     """
179     def __init__(self, is_flow=None):
180         self.is_flow = is_flow
182     def __call__(self, images, targets):
183         def func(img, img_idx, is_flow):
184             img = np.copy(np.flipud(img))
185             if is_flow:
186                 img = ImageTransformUtils.scale_flow(img, 1, (-1))
187             return img
189         def func_img(img, img_idx):
190             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
191             img = func(img, img_idx, is_flow_img)
192             return img
194         def func_tgt(img, img_idx):
195             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
196             img = func(img, img_idx, is_flow_tgt)
197             return img
199         if np.random.random() < 0.5:
200             images = ImageTransformUtils.apply_to_list(func_img, images)
201             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
203         return images, targets
206 class RandomRotate(object):
207     """Random rotation of the image from -angle to angle (in degrees)
208     This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
209     angle: max angle of the rotation
210     interpolation order: Default: 2 (bilinear)
211     reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
212     diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off.
213     """
214     def __init__(self, angle, diff_angle=0, is_flow=None):
215         self.angle = angle
216         self.diff_angle = diff_angle #if diff_angle else min(angle/2,10)
217         self.is_flow = is_flow
219     def __call__(self, images, targets):
220         applied_angle = random.uniform(-self.angle,self.angle)
221         is_input_image_pair = (len(images) == 2) and ((self.is_flow == None) or (not np.any(self.is_flow[0])))
222         diff = random.uniform(-self.diff_angle,self.diff_angle) if is_input_image_pair else 0
223         angles = [applied_angle - diff/2, applied_angle + diff/2] if is_input_image_pair else [applied_angle for img in images]
225         def func(img, img_idx, angle, interpolation, is_flow):
226             h, w = img.shape[:2]
227             angle_rad = (angle * np.pi / 180)
229             if is_flow:
230                 img = img.astype(np.float32)
231                 diff_rad = (diff * np.pi / 180)
232                 def rotate_flow(i, j, k):
233                     return -k * (j - w / 2) * diff_rad + (1 - k) * (i - h / 2) * diff_rad
234                 #
235                 rotate_flow_map = np.fromfunction(rotate_flow, img.shape)
236                 img += rotate_flow_map
238             img = ImageTransformUtils.rotate_img(img, angle, interpolation)
240             # flow vectors must be rotated too! careful about Y flow which is upside down
241             if is_flow:
242                 img = np.copy(img)
243                 img[:,:,0] = np.cos(angle_rad)*img[:,:,0] + np.sin(angle_rad)*img[:,:,1]
244                 img[:,:,1] = -np.sin(angle_rad)*img[:,:,0] + np.cos(angle_rad)*img[:,:,1]
246             return img
248         def func_img(img, img_idx):
249             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
250             interpolation = (cv2.INTER_NEAREST if is_flow_img else cv2.INTER_LINEAR)
251             img = func(img, img_idx, angles[img_idx], interpolation, is_flow_img)
252             return img
254         def func_tgt(img, img_idx):
255             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
256             interpolation = (cv2.INTER_NEAREST)
257             img = func(img, img_idx, applied_angle, interpolation, is_flow_tgt)
258             return img
260         if np.random.random() < 0.5:
261             images = ImageTransformUtils.apply_to_list(func_img, images)
262             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
264         return images, targets
267 class RandomColorWarp(object):
268     def __init__(self, mean_range=0, std_range=0, is_flow=None):
269         self.mean_range = mean_range
270         self.std_range = std_range
271         self.is_flow = is_flow
273     def __call__(self, images, target):
274         if np.random.random() < 0.5:
275             if self.std_range != 0:
276                 random_std = np.random.uniform(-self.std_range, self.std_range, 3)
277                 for img_idx in range(len(images)):
278                     is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
279                     if not is_flow_img:
280                         images[img_idx] *= (1 + random_std)
282             if self.mean_range != 0:
283                 random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3)
284                 for img_idx in range(len(images)):
285                     is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
286                     if not is_flow_img:
287                         images[img_idx] += random_mean
289             for img_idx in range(len(images)):
290                 is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
291                 if not is_flow_img:
292                     random_order = np.random.permutation(3)
293                     images[img_idx] = images[img_idx][:,:,random_order]
295         return images, target
298 class RandomColor2Gray(object):
299     def __init__(self, mean_range=0, std_range=0, is_flow=None, random_threshold=0.25):
300         self.mean_range = mean_range
301         self.std_range = std_range
302         self.is_flow = is_flow
303         self.random_threshold = random_threshold
305     def __call__(self, images, target):
306         if np.random.random() < self.random_threshold:
307             for img_idx in range(len(images)):
308                 is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
309                 if not is_flow_img:
310                     images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2GRAY)
311                     images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_GRAY2RGB)
312         return images, target
315 class RandomScaleCrop(object):
316     """Randomly zooms images up to 15% and crop them to keep same size as before."""
317     def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False):
318         self.img_resize = img_resize
319         self.scale_range = scale_range
320         self.is_flow = is_flow
321         self.center_crop = center_crop
323     @staticmethod
324     def get_params(img, img_resize, scale_range, center_crop):
325         in_h, in_w = img.shape[:2]
326         out_h, out_w = img_resize
327         # this random scaling is w.r.t. the output size
328         if (np.random.random() < 0.5):
329             resize_h = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
330             resize_w = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
331         else:
332             resize_h, resize_w = out_h, out_w
334         # crop params w.r.t the scaled size
335         out_r = (resize_h - out_h)//2 if center_crop else np.random.randint(resize_h - out_h + 1)
336         out_c = (resize_w - out_w)//2 if center_crop else np.random.randint(resize_w - out_w + 1)
337         return out_r, out_c, out_h, out_w, resize_h, resize_w
339     def __call__(self, images, targets):
340         out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop)
342         def func_img(img, img_idx):
343             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
344             img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img)
345             return img
347         def func_tgt(img, img_idx):
348             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
349             img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
350             return img
352         images = ImageTransformUtils.apply_to_list(func_img, images)
353         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
354         return images, targets
357 class RandomCropScale(object):
358     """Crop the Image to random size and scale to the given resolution"""
359     def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False):
360         self.size = size if (type(size) in (list,tuple)) else (size, size)
361         self.crop_range = crop_range
362         self.is_flow = is_flow
363         self.center_crop = center_crop
365     @staticmethod
366     def get_params(img, crop_range, center_crop):
367         h_orig = img.shape[0]; w_orig = img.shape[1]
368         r_wh = (w_orig/h_orig)
369         ratio =  (r_wh/2.0, 2.0*r_wh)
370         for attempt in range(10):
371             area = h_orig * w_orig
372             target_area = random.uniform(*crop_range) * area
373             aspect_ratio = random.uniform(*ratio)
374             w = int(round(math.sqrt(target_area * aspect_ratio)))
375             h = int(round(math.sqrt(target_area / aspect_ratio)))
376             if (h <= h_orig) and (w <= w_orig):
377                 i = (h_orig - h)//2 if center_crop else random.randint(0, h_orig - h)
378                 j = (w_orig - w)//2 if center_crop else random.randint(0, w_orig - w)
379                 return i, j, h, w
381         # Fallback: entire image
382         return 0, 0, h_orig, w_orig
384     def __call__(self, images, targets):
385         out_r, out_c, out_h, out_w = self.get_params(images[0], self.crop_range, self.center_crop)
387         def func_img(img, img_idx):
388             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
389             img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img)
390             return img
392         def func_tgt(img, img_idx):
393             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
394             img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
395             return img
397         images = ImageTransformUtils.apply_to_list(func_img, images)
398         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
400         return images, targets
403 class Scale(object):
404     def __init__(self, img_size, target_size=None, is_flow=None):
405         self.img_size = img_size
406         self.target_size = target_size if target_size else img_size
407         self.is_flow = is_flow
409     def __call__(self, images, targets):
410         if self.img_size is None:
411             return images, targets
413         def func_img(img, img_idx):
414             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
415             img = ImageTransformUtils.resize_img(img, self.img_size, interpolation=-1, is_flow=is_flow_img)
416             return img
418         def func_tgt(img, img_idx):
419             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
420             img = ImageTransformUtils.resize_img(img, self.target_size, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
421             return img
423         images = ImageTransformUtils.apply_to_list(func_img, images)
424         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
426         return images, targets
429 class CropRect(object):
430     def __init__(self, crop_rect):
431         self.crop_rect = crop_rect
433     def __call__(self, images, targets):
434         if self.crop_rect is None:
435             return images, targets
437         def func(img, tgt, img_idx):
438             img_size = img.shape
439             crop_rect = self.crop_rect
440             max_val = max(crop_rect)
441             t, l, h, w = crop_rect
442             if max_val <= 1:  # convert float into integer
443                 t = int(t * img_size[0] + 0.5)  # top
444                 l = int(l * img_size[1] + 0.5)  # left
445                 h = int(h * img_size[0] + 0.5)  # height
446                 w = int(w * img_size[1] + 0.5)  # width
447             else:
448                 t = int(t)  # top
449                 l = int(l)  # left
450                 h = int(h)  # height
451                 w = int(w)  # width
453             img = img[t:(t+h), l:(l+w)]
454             tgt = tgt[t:(t+h), l:(l+w)]
456             return img, tgt
458         images = ImageTransformUtils.apply_to_list(func, images)
459         targets = ImageTransformUtils.apply_to_list(func, targets)
461         return images, targets
464 class MaskTarget(object):
465     def __init__(self, mask_rect, mask_val):
466         self.mask_rect = mask_rect
467         self.mask_val = mask_val
469     def __call__(self, images, targets):
470         if self.mask_rect is None:
471             return images, targets
473         def func(img, tgt, img_idx):
474             img_size = img.shape
475             crop_rect = self.mask_rect
476             max_val = max(crop_rect)
477             t, l, h, w = crop_rect
478             if max_val <= 1:  # convert float into integer
479                 t = int(t * img_size[0] + 0.5)  # top
480                 l = int(l * img_size[1] + 0.5)  # left
481                 h = int(h * img_size[0] + 0.5)  # height
482                 w = int(w * img_size[1] + 0.5)  # width
483             else:
484                 t = int(t)  # top
485                 l = int(l)  # left
486                 h = int(h)  # height
487                 w = int(w)  # width
489             tgt[t:(t+h), l:(l+w)] = self.mask_val
490             return img, tgt
492         images = ImageTransformUtils.apply_to_list(func, images)
493         targets = ImageTransformUtils.apply_to_list(func, targets)
495         return images, targets
498 class NormalizeMeanStd(object):
499     def __init__(self, mean, std):
500         self.mean = mean
501         self.std = std
503     def __call__(self, images, target):
504         if isinstance(images, (list,tuple)):
505             images = [(img-self.mean)/self.std for img in images]
506         else:
507             images = (images-self.mean)/self.std
509         return images, target
512 class NormalizeMeanScale(object):
513     def __init__(self, mean, scale):
514         self.mean = mean
515         self.scale = scale
517     def __call__(self, images, target):
518         if isinstance(images, (list,tuple)):
519             images = [(img-self.mean)*self.scale for img in images]
520         else:
521             images = (images-self.mean)*self.scale
523         return images, target