torch.nn.ReLU is the recommended activation module. removed the custom defined module...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / transforms / image_transform_utils.py
2 import numpy as np
3 import cv2
4 import torch
5 import types
7 class Compose(object):
8     """ Composes several co_transforms together.
9     For example:
10     >>> co_transforms.Compose([
11     >>>     co_transforms.CenterCrop(10),
12     >>>     co_transforms.ToTensor(),
13     >>>  ])
14     """
16     def __init__(self, t):
17         self.co_transforms = t
19     def extend(self, t):
20         self.co_transforms.extend(t)
22     def insert(self, index, t):
23         self.co_transforms.insert(index, t)
25     def write_img(self, img=[], ch_num=-1, name='', en=False):
26         if en == False:
27             return
28         #name = './data/checkpoints/tiad_interest_pt_descriptor/debug/{:02d}.jpg'.format(aug_idx)
29         scale_range = 255.0 / np.max(img)
30         img = np.clip(img * scale_range, 0.0, 255.0)
31         img = np.asarray(img, 'uint8')
33         non_zero_el = cv2.countNonZero(img)
35         print("non zero element: {}".format(non_zero_el))
36         cv2.imwrite('{}_nz{}.jpg'.format(name, non_zero_el), img)
38     def __call__(self, input, target):
39         if self.co_transforms:
40             for aug_idx, t in enumerate(self.co_transforms):
41                 if t:
42                     input,target = t(input,target)
44         return input,target
48 class Bypass(object):
49     def __init__(self):
50         pass
52     def __call__(self, images, targets):
53         return images,targets
56 class Lambda(object):
57     """Applies a lambda as a transform"""
59     def __init__(self, lambd):
60         assert isinstance(lambd, types.LambdaType)
61         self.lambd = lambd
63     def __call__(self, input,target):
64         return self.lambd(input,target)
67 class ImageTransformUtils(object):
68     @staticmethod
69     def apply_to_list(func, inputs):
70         for img_idx in range(len(inputs)):
71             inputs[img_idx] = func(inputs[img_idx], img_idx)
73         return inputs
75     @staticmethod
76     def crop(img, r, c, h, w):
77         img = img[r:(r+h), c:(c+w),...] if (len(img.shape)>2) else img[r:(r+h), c:(c+w)]
78         return img
80     @staticmethod
81     def resize_fast(img, output_size_rc, interpolation=-1):
82         in_h, in_w = img.shape[:2]
83         out_h, out_w = output_size_rc
84         if interpolation<0:
85             interpolation = cv2.INTER_AREA if ((out_h<in_h) or (out_w<in_w)) else cv2.INTER_LINEAR
87         img = cv2.resize(img, (out_w,out_h), interpolation=interpolation) #opencv expects size in (w,h) format
88         img = img[...,np.newaxis] if len(img.shape) < 3 else img
89         return img
91     @staticmethod
92     def resize_img(img, size, interpolation=-1, is_flow=False):
93         #if (len(img.shape) == 3) and (img.shape[2] == 1 or img.shape[2] == 3):
94         #    return __class__.resize_fast(img, size, interpolation)
96         in_h, in_w = img.shape[:2]
97         out_h, out_w = size
98         if interpolation<0:
99             interpolation = cv2.INTER_AREA if ((out_h<in_h) or (out_w<in_w)) else cv2.INTER_LINEAR
101         # opencv handles planar, 1 or 3 channel images
102         img = img[...,np.newaxis] if len(img.shape) < 3 else img
103         num_chans = img.shape[2]
104         img = np.concatenate([img]+[img[...,0:1]]*(3-num_chans), axis=2) if num_chans<3 else img
105         img = cv2.resize(img, (out_w, out_h), interpolation=interpolation)
106         img = img[...,:num_chans]
108         if is_flow:
109             ratio_h = out_h / in_h
110             ratio_w = out_w / in_w
111             img = ImageTransformUtils.scale_flow(img, ratio_w, ratio_h)
113         return img
115     @staticmethod
116     def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
117         if resize_in_yv12:
118             yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
119             yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
120             img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
121         else:
122             img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
123         #
124         img = ImageTransformUtils.crop(img, r, c, h, w)
125         return img
127     @staticmethod
128     def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
129         img = ImageTransformUtils.crop(img, r, c, h, w)
130         if resize_in_yv12:
131             yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
132             yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
133             img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
134         else:
135             img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
136         #
137         return img
139     @staticmethod
140     def rotate_img(img, angle, interpolation=-1):
141         h, w = img.shape[:2]
142         rmat2x3 = cv2.getRotationMatrix2D(center=(w//2,h//2), angle=angle, scale=1.0)
143         interpolation = cv2.INTER_NEAREST if interpolation < 0 else interpolation
144         img = cv2.warpAffine(img, rmat2x3, (w,h), flags=interpolation)
145         return img
147     @staticmethod
148     def array_to_tensor(array):
149         """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
150         assert(isinstance(array, np.ndarray))
151         # put it from HWC to CHW format
152         array = np.transpose(array, (2, 0, 1))
153         if len(array.shape) < 3:
154             array = array[np.newaxis, ...]
155         #
156         tensor = torch.from_numpy(array)
157         return tensor.float()
159     @staticmethod
160     def scale_flow(flow, ratio_x, ratio_y):
161         flow = flow.astype(np.float32)
162         flow[...,0] *= ratio_x
163         flow[...,1] *= ratio_y
164         return flow
166     @staticmethod
167     def scale_flows(inputs, ratio_x, ratio_y, is_flow):
168         for img_idx in range(len(inputs)):
169             if is_flow and is_flow[img_idx]:
170                 inputs[img_idx] = inputs[img_idx].astype(np.float32)
171                 inputs[img_idx][...,0] *= ratio_x
172                 inputs[img_idx][...,1] *= ratio_y
173         #
174         return inputs
177     #############################################################
178     # functions for nv12
180     @staticmethod
181     def resize_img_yv12(img, size, interpolation=-1, is_flow=False):
182         #if (len(img.shape) == 3) and (img.shape[2] == 1 or img.shape[2] == 3):
183         #    return __class__.resize_fast(img, size, interpolation)
184         debug_print = False
185         in_w = img.shape[1]
186         in_h = (img.shape[0] * 2) // 3
187         y_h = in_h
188         uv_h = in_h // 4
189         u_w = in_w // 2
191         Y = img[0:in_h, 0:in_w]
192         V = img[y_h:y_h + uv_h, 0:in_w]
193         #print(V[0:2,0:8])
194         #print(V[0:2, u_w:u_w+8])
195         V = V.reshape(V.shape[0]*2, -1)
196         #print(V[0:2, 0:8])
197         #print(V[0:2, u_w:u_w + 8])
198         U = img[y_h + uv_h:y_h + 2 * uv_h, 0:in_w]
199         U = U.reshape(U.shape[0] * 2, -1)
201         out_h, out_w = size
202         if interpolation < 0:
203             interpolation = cv2.INTER_AREA if ((out_h < in_h) or (out_w < in_w)) else cv2.INTER_LINEAR
205         Y = cv2.resize(Y, (out_w, out_h), interpolation=interpolation)
206         U = cv2.resize(U, (out_w//2, out_h//2), interpolation=interpolation)
207         V = cv2.resize(V, (out_w//2, out_h//2), interpolation=interpolation)
209         img = np.zeros((out_h*3//2, out_w), dtype='uint8')
210         op_uv_h = out_h // 4
212         img[0:out_h, 0:out_w] = Y[:, :]
213         #print(V[0:2,0:8])
214         V = V.reshape(V.shape[0] // 2, -1)
215         #print(V[0:1,0:8])
216         #print(V[0:1, op_u_w:op_u_w+8])
217         img[out_h:out_h + op_uv_h, 0:out_w] = V
218         U = U.reshape(U.shape[0] // 2, -1)
219         img[out_h + op_uv_h:out_h + 2 * op_uv_h, 0:out_w] = U
221         if debug_print:
222             h = img.shape[0] * 2 // 3
223             w = img.shape[1]
224             print("-" * 32, "Resize in YV12")
225             print("Y")
226             print(img[0:5, 0:5])
228             print("V Odd Lines")
229             print(img[h:h + 5, 0:5])
231             print("V Even Lines")
232             print(img[h:h + 5, w // 2:w // 2 + 5])
234             print("U Odd Lines")
235             print(img[h + h // 4:h + h // 4 + 5, 0:5])
237             print("U Even Lines")
238             print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
240             print("-" * 32)
242         if is_flow:
243             ratio_h = out_h / in_h
244             ratio_w = out_w / in_w
245             img = ImageTransformUtils.scale_flow(img, ratio_w, ratio_h)
247         return img
250     # @staticmethod
251     # def resize_and_crop_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
252     #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
253     #     yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
254     #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGb_YV12)
255     #     img = ImageTransformUtils.crop(img, r, c, h, w)
256     #     return img
257     #
258     # @staticmethod
259     # def crop_and_resize_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
260     #     img = ImageTransformUtils.crop(img, r, c, h, w)
261     #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
262     #     yv12 = ImageTransformUtils.resize_img_yv12(img, size, interpolation, is_flow)
263     #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
264     #     return img