torch.nn.ReLU is the recommended activation module. removed the custom defined module...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / transforms / image_transforms.py
1 from __future__ import division
2 import random
3 import numbers
4 import math
5 import cv2
6 import numpy as np
7 import PIL
9 from .image_transform_utils import *
12 class CheckImages(object):
13     def __call__(self, images, targets):
14         assert isinstance(images, (list, tuple)), 'Input must a list'
15         assert isinstance(targets, (list, tuple)), 'Target must a list'
16         #assert images[0].shape[:2] == targets[0].shape[:2], 'Image and target sizes must match.'
18         for img_idx in range(len(images)):
19             assert images[img_idx].shape[:2] == images[0].shape[:2], 'Image sizes must match. Either provide same size images or use AlignImages() instead of CheckImages()'
20             images[img_idx] = images[img_idx][...,np.newaxis] if (images[img_idx].shape) == 2 else images[img_idx]
22         for img_idx in range(len(targets)):
23             assert targets[img_idx].shape[:2] == targets[0].shape[:2], 'Target sizes must match. Either provide same size targets or use AlignImages() instead of CheckImages()'
24             targets[img_idx] = targets[img_idx][...,np.newaxis] if (targets[img_idx].shape) == 2 else targets[img_idx]
26         return images, targets
29 class AlignImages(object):
30     """Resize everything to the first image's size, before transformations begin.
31        Also make sure the images are in the desired format."""
32     def __init__(self, is_flow=None):
33         self.is_flow = is_flow
35     def __call__(self, images, targets):
36         images = images if isinstance(images, (list,tuple)) else [images]
37         images = [np.array(img) if isinstance(img,PIL.Image.Image) else img for img in images]
39         targets = targets if isinstance(targets, (list,tuple)) else [targets]
40         targets = [np.array(tgt) if isinstance(tgt,PIL.Image.Image) else tgt for tgt in targets]
42         img_size = images[0].shape[:2]
43         images, targets = Scale(img_size, is_flow=self.is_flow)(images, targets)
44         CheckImages()(images, targets)
45         return images, targets
48 class ConvertToTensor(object):
49     """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
50     def __call__(self, images, targets):
51         def func(imgs, img_idx):
52             imgs = [ImageTransformUtils.array_to_tensor(img_plane) for img_plane in imgs] \
53                 if isinstance(imgs, list) else ImageTransformUtils.array_to_tensor(imgs)
54             return imgs
56         images = ImageTransformUtils.apply_to_list(func, images)
57         targets = ImageTransformUtils.apply_to_list(func, targets)
58         return images, targets
61 class CenterCrop(object):
62     """Crops the given inputs and target arrays at the center to have a region of
63     the given size. size can be a tuple (target_height, target_width)
64     or an integer, in which case the target will be of a square shape (size, size)
65     Careful, img1 and img2 may not be the same size"""
66     def __init__(self, size):
67         self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
69     def __call__(self, images, targets):
70         def func(img, tgt, img_idx):
71             th, tw = self.size
72             h1, w1, _ = img.shape
73             x1 = int(round((w1 - tw) / 2.))
74             y1 = int(round((h1 - th) / 2.))
75             img = img[y1: y1 + th, x1: x1 + tw]
76             tgt = tgt[y1: y1 + th, x1: x1 + tw]
77             return img, tgt
79         images = ImageTransformUtils.apply_to_list(func, images)
80         targets = ImageTransformUtils.apply_to_list(func, targets)
82         return images, targets
85 class ScaleMinSide(object):
86     """ Rescales the inputs and target arrays to the given 'size'.
87     After scaling, 'size' will be the size of the smaller edge.
88     For example, if height > width, then image will be rescaled to (size * height / width, size)"""
89     def __init__(self, size, is_flow=None):
90         self.size = size
91         self.is_flow = is_flow
93     def __call__(self, images, targets):
94         def func(img, img_idx, interpolation, is_flow):
95             h, w, _ = img.shape
96             if (w <= h and w == self.size) or (h <= w and h == self.size):
97                 ratio = 1.0
98                 size_out = (h, w)
99             else:
100                 if w < h:
101                     ratio = self.size / w
102                     size_out = (int(round(ratio * h)), self.size)
103                 else:
104                     ratio = self.size / h
105                     size_out = (self.size, int(round(ratio * w)))
106             #
108             img = ImageTransformUtils.resize_img(img, size_out, interpolation=interpolation, is_flow=is_flow)
109             return img
111         def func_img(img, img_idx):
112             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
113             return func(img, img_idx, interpolation=-1, is_flow=is_flow_img)
115         def func_tgt(img, img_idx):
116             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
117             return func(img, img_idx, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
119         images = ImageTransformUtils.apply_to_list(func_img, images)
120         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
122         return images, targets
125 class RandomCrop(object):
126     """Crops the given images"""
127     def __init__(self, size):
128         if isinstance(size, numbers.Number):
129             self.size = (int(size), int(size))
130         else:
131             self.size = size
133     def __call__(self, images, targets):
134         size_h, size_w, _ = images[0].shape
135         th, tw = self.size
136         x1 = np.random.randint(0, size_w - tw) if (size_w>tw) else 0
137         y1 = np.random.randint(0, size_h - th) if (size_h>th) else 0
139         def func(img, img_idx):
140             return img[y1:y1+th, x1:x1+tw]
142         images = ImageTransformUtils.apply_to_list(func, images)
143         targets = ImageTransformUtils.apply_to_list(func, targets)
145         return images, targets
148 class RandomHorizontalFlip(object):
149     """Randomly horizontally flips the given images"""
150     def __init__(self, is_flow=None):
151         self.is_flow = is_flow
153     def __call__(self, images, targets):
154         def func(img, img_idx, is_flow):
155             img = np.copy(np.fliplr(img))
156             if is_flow:
157                 img = ImageTransformUtils.scale_flow(img, (-1), 1)
158             return img
160         def func_img(img, img_idx):
161             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
162             img = func(img, img_idx, is_flow_img)
163             return img
165         def func_tgt(img, img_idx):
166             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
167             img = func(img, img_idx, is_flow_tgt)
168             return img
170         if np.random.random() < 0.5:
171             images = ImageTransformUtils.apply_to_list(func_img, images)
172             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
174         return images, targets
177 class RandomVerticalFlip(object):
178     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
179     """
180     def __init__(self, is_flow=None):
181         self.is_flow = is_flow
183     def __call__(self, images, targets):
184         def func(img, img_idx, is_flow):
185             img = np.copy(np.flipud(img))
186             if is_flow:
187                 img = ImageTransformUtils.scale_flow(img, 1, (-1))
188             return img
190         def func_img(img, img_idx):
191             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
192             img = func(img, img_idx, is_flow_img)
193             return img
195         def func_tgt(img, img_idx):
196             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
197             img = func(img, img_idx, is_flow_tgt)
198             return img
200         if np.random.random() < 0.5:
201             images = ImageTransformUtils.apply_to_list(func_img, images)
202             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
204         return images, targets
207 class RandomRotate(object):
208     """Random rotation of the image from -angle to angle (in degrees)
209     This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
210     angle: max angle of the rotation
211     interpolation order: Default: 2 (bilinear)
212     reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
213     diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off.
214     """
215     def __init__(self, angle, diff_angle=0, is_flow=None):
216         self.angle = angle
217         self.diff_angle = diff_angle #if diff_angle else min(angle/2,10)
218         self.is_flow = is_flow
220     def __call__(self, images, targets):
221         applied_angle = random.uniform(-self.angle,self.angle)
222         is_input_image_pair = (len(images) == 2) and ((self.is_flow == None) or (not np.any(self.is_flow[0])))
223         diff = random.uniform(-self.diff_angle,self.diff_angle) if is_input_image_pair else 0
224         angles = [applied_angle - diff/2, applied_angle + diff/2] if is_input_image_pair else [applied_angle for img in images]
226         def func(img, img_idx, angle, interpolation, is_flow):
227             h, w = img.shape[:2]
228             angle_rad = (angle * np.pi / 180)
230             if is_flow:
231                 img = img.astype(np.float32)
232                 diff_rad = (diff * np.pi / 180)
233                 def rotate_flow(i, j, k):
234                     return -k * (j - w / 2) * diff_rad + (1 - k) * (i - h / 2) * diff_rad
235                 #
236                 rotate_flow_map = np.fromfunction(rotate_flow, img.shape)
237                 img += rotate_flow_map
239             img = ImageTransformUtils.rotate_img(img, angle, interpolation)
241             # flow vectors must be rotated too! careful about Y flow which is upside down
242             if is_flow:
243                 img = np.copy(img)
244                 img[:,:,0] = np.cos(angle_rad)*img[:,:,0] + np.sin(angle_rad)*img[:,:,1]
245                 img[:,:,1] = -np.sin(angle_rad)*img[:,:,0] + np.cos(angle_rad)*img[:,:,1]
247             return img
249         def func_img(img, img_idx):
250             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
251             interpolation = (cv2.INTER_NEAREST if is_flow_img else cv2.INTER_LINEAR)
252             img = func(img, img_idx, angles[img_idx], interpolation, is_flow_img)
253             return img
255         def func_tgt(img, img_idx):
256             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
257             interpolation = (cv2.INTER_NEAREST)
258             img = func(img, img_idx, applied_angle, interpolation, is_flow_tgt)
259             return img
261         if np.random.random() < 0.5:
262             images = ImageTransformUtils.apply_to_list(func_img, images)
263             targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
265         return images, targets
268 class RandomColorWarp(object):
269     def __init__(self, mean_range=0, std_range=0, is_flow=None):
270         self.mean_range = mean_range
271         self.std_range = std_range
272         self.is_flow = is_flow
274     def __call__(self, images, target):
275         if np.random.random() < 0.5:
276             if self.std_range != 0:
277                 random_std = np.random.uniform(-self.std_range, self.std_range, 3)
278                 for img_idx in range(len(images)):
279                     is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
280                     if not is_flow_img:
281                         images[img_idx] *= (1 + random_std)
283             if self.mean_range != 0:
284                 random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3)
285                 for img_idx in range(len(images)):
286                     is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
287                     if not is_flow_img:
288                         images[img_idx] += random_mean
290             for img_idx in range(len(images)):
291                 is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
292                 if not is_flow_img:
293                     random_order = np.random.permutation(3)
294                     images[img_idx] = images[img_idx][:,:,random_order]
296         return images, target
299 class RandomColor2Gray(object):
300     def __init__(self, mean_range=0, std_range=0, is_flow=None, random_threshold=0.25):
301         self.mean_range = mean_range
302         self.std_range = std_range
303         self.is_flow = is_flow
304         self.random_threshold = random_threshold
306     def __call__(self, images, target):
307         if np.random.random() < self.random_threshold:
308             for img_idx in range(len(images)):
309                 is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
310                 if not is_flow_img:
311                     images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2GRAY)
312                     images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_GRAY2RGB)
313         return images, target
316 class RandomScaleCrop(object):
317     """Randomly zooms images up to 15% and crop them to keep same size as before."""
318     def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False, resize_in_yv12=False):
319         self.img_resize = img_resize
320         self.scale_range = scale_range
321         self.is_flow = is_flow
322         self.center_crop = center_crop
323         self.resize_in_yv12 = resize_in_yv12
325     @staticmethod
326     def get_params(img, img_resize, scale_range, center_crop, resize_in_yv12 = False):
327         in_h, in_w = img.shape[:2]
328         out_h, out_w = img_resize
329         if resize_in_yv12:
330             #to make U,V as multiple of 4 shape to properly represent in YV12 format
331             round_or_align4 = lambda x: ((int(x)//4)*4)
332         else:
333             round_or_align4 = lambda x: round(x)
334         # this random scaling is w.r.t. the output size
335         if (np.random.random() < 0.5):
336             resize_h = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
337             resize_w = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
338         else:
339             resize_h, resize_w = out_h, out_w
341         # crop params w.r.t the scaled size
342         out_r = (resize_h - out_h)//2 if center_crop else np.random.randint(resize_h - out_h + 1)
343         out_c = (resize_w - out_w)//2 if center_crop else np.random.randint(resize_w - out_w + 1)
344         return out_r, out_c, out_h, out_w, resize_h, resize_w
346     def __call__(self, images, targets):
347         out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop,
348                                                                          resize_in_yv12 = self.resize_in_yv12)
350         def func_img(img, img_idx):
351             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
352             img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
353             return img
355         def func_tgt(img, img_idx):
356             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
357             img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
358             return img
360         images = ImageTransformUtils.apply_to_list(func_img, images)
361         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
362         return images, targets
365 class RandomCropScale(object):
366     """Crop the Image to random size and scale to the given resolution"""
367     def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False, resize_in_yv12=False):
368         self.size = size if (type(size) in (list,tuple)) else (size, size)
369         self.crop_range = crop_range
370         self.is_flow = is_flow
371         self.center_crop = center_crop
372         self.resize_in_yv12 = resize_in_yv12
374     @staticmethod
375     def get_params(img, crop_range, center_crop):
376         h_orig = img.shape[0]; w_orig = img.shape[1]
377         r_wh = (w_orig/h_orig)
378         ratio =  (r_wh/2.0, 2.0*r_wh)
379         for attempt in range(10):
380             area = h_orig * w_orig
381             target_area = random.uniform(*crop_range) * area
382             aspect_ratio = random.uniform(*ratio)
383             w = int(round(math.sqrt(target_area * aspect_ratio)))
384             h = int(round(math.sqrt(target_area / aspect_ratio)))
385             if (h <= h_orig) and (w <= w_orig):
386                 i = (h_orig - h)//2 if center_crop else random.randint(0, h_orig - h)
387                 j = (w_orig - w)//2 if center_crop else random.randint(0, w_orig - w)
388                 return i, j, h, w
390         # Fallback: entire image
391         return 0, 0, h_orig, w_orig
393     def __call__(self, images, targets):
394         out_r, out_c, out_h, out_w = self.get_params(images[0], self.crop_range, self.center_crop)
396         def func_img(img, img_idx):
397             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
398             img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
399             return img
401         def func_tgt(img, img_idx):
402             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
403             img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
404             return img
406         images = ImageTransformUtils.apply_to_list(func_img, images)
407         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
409         return images, targets
412 class Scale(object):
413     def __init__(self, img_size, target_size=None, is_flow=None):
414         self.img_size = img_size
415         self.target_size = target_size if target_size else img_size
416         self.is_flow = is_flow
418     def __call__(self, images, targets):
419         if self.img_size is None:
420             return images, targets
422         def func_img(img, img_idx):
423             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
424             img = ImageTransformUtils.resize_img(img, self.img_size, interpolation=-1, is_flow=is_flow_img)
425             return img
427         def func_tgt(img, img_idx):
428             is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
429             img = ImageTransformUtils.resize_img(img, self.target_size, interpolation=cv2.INTER_NEAREST, is_flow=is_flow_tgt)
430             return img
432         images = ImageTransformUtils.apply_to_list(func_img, images)
433         targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
435         return images, targets
438 class CropRect(object):
439     def __init__(self, crop_rect):
440         self.crop_rect = crop_rect
442     def __call__(self, images, targets):
443         if self.crop_rect is None:
444             return images, targets
446         def func(img, tgt, img_idx):
447             img_size = img.shape
448             crop_rect = self.crop_rect
449             max_val = max(crop_rect)
450             t, l, h, w = crop_rect
451             if max_val <= 1:  # convert float into integer
452                 t = int(t * img_size[0] + 0.5)  # top
453                 l = int(l * img_size[1] + 0.5)  # left
454                 h = int(h * img_size[0] + 0.5)  # height
455                 w = int(w * img_size[1] + 0.5)  # width
456             else:
457                 t = int(t)  # top
458                 l = int(l)  # left
459                 h = int(h)  # height
460                 w = int(w)  # width
462             img = img[t:(t+h), l:(l+w)]
463             tgt = tgt[t:(t+h), l:(l+w)]
465             return img, tgt
467         images = ImageTransformUtils.apply_to_list(func, images)
468         targets = ImageTransformUtils.apply_to_list(func, targets)
470         return images, targets
473 class MaskTarget(object):
474     def __init__(self, mask_rect, mask_val):
475         self.mask_rect = mask_rect
476         self.mask_val = mask_val
478     def __call__(self, images, targets):
479         if self.mask_rect is None:
480             return images, targets
482         def func(img, tgt, img_idx):
483             img_size = img.shape
484             crop_rect = self.mask_rect
485             max_val = max(crop_rect)
486             t, l, h, w = crop_rect
487             if max_val <= 1:  # convert float into integer
488                 t = int(t * img_size[0] + 0.5)  # top
489                 l = int(l * img_size[1] + 0.5)  # left
490                 h = int(h * img_size[0] + 0.5)  # height
491                 w = int(w * img_size[1] + 0.5)  # width
492             else:
493                 t = int(t)  # top
494                 l = int(l)  # left
495                 h = int(h)  # height
496                 w = int(w)  # width
498             tgt[t:(t+h), l:(l+w)] = self.mask_val
499             return img, tgt
501         images = ImageTransformUtils.apply_to_list(func, images)
502         targets = ImageTransformUtils.apply_to_list(func, targets)
504         return images, targets
507 class NormalizeMeanStd(object):
508     def __init__(self, mean, std):
509         self.mean = mean
510         self.std = std
512     def __call__(self, images, target):
513         if isinstance(images, (list,tuple)):
514             images = [(img-self.mean)/self.std for img in images]
515         else:
516             images = (images-self.mean)/self.std
518         return images, target
521 class NormalizeMeanScale(object):
522     def __init__(self, mean, scale):
523         self.mean = mean
524         self.scale = scale
527     def __call__(self, images, target):
528         def func(imgs, img_idx):
529             if isinstance(imgs, (list,tuple)):
530                 imgs = [(img-self.mean)*self.scale for img in imgs]
531             else:
532                 imgs = (imgs-self.mean)*self.scale
533             #
534             return imgs
535         #
536         images = ImageTransformUtils.apply_to_list(func, images) \
537             if isinstance(images, (list,tuple)) else func(images)
538         return images, target