]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/utils/tensor_utils.py
70c90c54118edeeb00f45b13a9c347d05cc23933
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / tensor_utils.py
1 import math
2 import random
3 import numpy as np
4 import torch
5 import scipy
6 import warnings
7 import cv2
8 from ..layers import functional
9 from . import image_utils
12 ###############################################################
13 # signed_log: a logarithmic representation with sign
14 def signed_log(x, base):
15     def log_fn(x):
16         return torch.log2(x)/np.log2(base)
17     #
18     # not using torch.sign as it doesn't have gradient
19     sign = (x < 0) * (-1) + (x >= 0) * (+1)
20     y = log_fn(torch.abs(x) + 1.0)
21     y = y * sign
22     return y
25 # convert back to linear from signed_log
26 def signed_pow(x, base):
27     # not using torch.sign as it doesn't have gradient
28     sign = (x < 0) * (-1) + (x >= 0) * (+1)
29     y = torch.pow(base, torch.abs(x)) - 1.0
30     y = y * sign
31     return y
34 ###############################################################
35 def extrema_fast(src, percentile_range_shrink=0.0, sigma=0.0, fast_mode=True):
36     return extrema(src, percentile_range_shrink, sigma, fast_mode)
39 def extrema(src, percentile_range_shrink=0.0, sigma=0.0, fast_mode=False):
40     if percentile_range_shrink == 0 and sigma == 0:
41         mn = src.min()
42         mx = src.max()
43         return (mn, mx)
44     elif percentile_range_shrink:
45         # downsample for fast_mode
46         fast_stride = 2
47         fast_stride2 = fast_stride*2
48         if fast_mode and len(src.size())==4 and (src.size(2)>fast_stride2) and (src.size(3)>fast_stride2):
49             r_start = random.randint(0, fast_stride-1)
50             c_start = random.randint(0, fast_stride-1)
51             src = src[..., r_start::fast_stride, c_start::fast_stride]
52         #
53         mn = src.min()
54         mx = src.max()
55         if mn ==0 and mx == 0:
56             return mn, mx
57         #
59         # compute percentile_range_shrink based min/max
60         # frequency - bincount can only operate on unsigned
61         num_bins = 255.0
62         cum_freq = float(100.0)
63         offset = mn
64         range_val = torch.abs(mx - mn)
65         mult_factor = (num_bins / range_val)
66         tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
67         tensor_int = functional.round_g(tensor_int).int()
69         # numpy version
70         #hist = np.bincount(tensor_int.cpu().numpy())
71         #hist_sum = np.sum(hist)
72         #hist_array = hist.astype(np.float32) * cum_freq / float(hist_sum)
74         # torch version
75         hist = torch.bincount(tensor_int)
76         hist_sum = torch.sum(hist)
77         hist = hist.float() * cum_freq / hist_sum.float()
78         hist_array = hist.cpu().numpy()
80         new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, percentile_range_shrink)
81         new_mn = (new_mn_scaled / mult_factor) + offset
82         new_mx = (new_mx_scaled / mult_factor) + offset
84         # take care of floating point inaccuracies that can
85         # increase the range (in rare cases) beyond the actual range.
86         new_mn = max(mn, new_mn)
87         new_mx = min(mx, new_mx)
88         return new_mn, new_mx
89     elif sigma:
90         mean = torch.mean(src)
91         std = torch.std(src)
92         mn = mean - sigma*std
93         mx = mean + sigma*std
94         return mn, mx
95     else:
96         assert False, 'unknown extrema computation mode'
99 # this code is not parallelizable. better to pass a numpy array
100 def extrema_hist_search(hist_array, percentile_range_shrink):
101     new_mn_scaled = 0
102     new_mx_scaled = len(hist_array) - 1
103     hist_sum_left = 0.0
104     hist_sum_right = 0.0
105     for h_idx in range(len(hist_array)):
106         r_idx = len(hist_array) - 1 - h_idx
107         hist_sum_left += hist_array[h_idx]
108         hist_sum_right += hist_array[r_idx]
109         if hist_sum_left < percentile_range_shrink:
110             new_mn_scaled = h_idx
111         if hist_sum_right < percentile_range_shrink:
112             new_mx_scaled = r_idx
113         #
114     #
115     return new_mn_scaled, new_mx_scaled
118 ##################################################################
119 def check_sizes(input, input_name, expected):
120     condition = [input.ndimension() == len(expected)]
121     for i,size in enumerate(expected):
122         if size.isdigit():
123             condition.append(input.size(i) == int(size))
124     assert(all(condition)), "wrong size for {}, expected {}, got  {}".format(input_name, 'x'.join(expected), list(input.size()))
127 ###########################################################################
128 def tensor2img(tensor, adjust_range=True, min_value = None, max_value=None):
129     if tensor.ndimension() < 3:
130         tensor = tensor.unsqueeze(0)
131     if tensor.ndimension() < 4:
132         tensor = tensor.unsqueeze(0)
133     if min_value is None:
134         min_value = tensor.min()
135     if max_value is None:
136         max_value = tensor.max()
137     range = max_value-min_value
138     array = (255*(tensor - min_value)/range).clamp(0,255) if adjust_range else tensor
139     if array.size(1) >= 3:
140         img = torch.stack((array[0,0], array[0,1], array[0,2]), dim=2)
141     else:
142         img = array[0,0]
143     return img.cpu().data.numpy().astype(np.uint8)
146 def flow2rgb(flow_map, max_value):
147     global args
148     _, h, w = flow_map.shape
149     #flow_map[:,(flow_map[0] == 0) & (flow_map[1] == 0)] = float('nan')
150     rgb_map = np.ones((h,w,3)).astype(np.float32)
151     if max_value is not None:
152         normalized_flow_map = flow_map / max_value
153     else:
154         normalized_flow_map = flow_map / (np.abs(flow_map).max())
155     rgb_map[:,:,0] += normalized_flow_map[0]
156     rgb_map[:,:,1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
157     rgb_map[:,:,2] += normalized_flow_map[1]
158     return rgb_map.clip(0,1)
161 def flow2hsv(flow_map, max_value=128, scale_fact=8, confidence=False):
162     global args
163     _, h, w = flow_map.shape
164     hsv = np.zeros((h, w, 3)).astype(np.float32)
166     mag = np.sqrt(flow_map[0]**2 + flow_map[1]**2)
167     phase = np.arctan2(flow_map[1], flow_map[0])
168     phase = np.mod(phase/(2*np.pi), 1)
170     hsv[:, :, 0] = phase*360
171     hsv[:, :, 1] = (mag*scale_fact/max_value).clip(0, 1)
172     hsv[:, :, 2] = (scale_fact - hsv[:, :, 1]).clip(0, 1)
173     rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
174     if confidence:
175         return rgb * flow_map[2] > 128
176     else:
177         return rgb
180 def tensor2array(tensor, max_value=255.0, colormap='rainbow', input_blend=None):
181     max_value = float(tensor.max()) if max_value is None else max_value
183     if tensor.ndimension() == 2 or tensor.size(0) == 1:
184         try:
185             import cv2
186             if cv2.__version__.startswith('2') :
187                 color_cvt = cv2.cv.CV_BGR2RGB
188             else:  # 3.x,4,x
189                 color_cvt = cv2.COLOR_BGR2RGB
190             #
191             if colormap == 'rainbow':
192                 colormap = cv2.COLORMAP_RAINBOW
193             elif colormap == 'magma': # >=3.4.8
194                 colormap = cv2.COLORMAP_MAGMA
195             elif colormap == 'bone':
196                 colormap = cv2.COLORMAP_BONE
197             elif colormap == 'plasma': # >=4.1
198                 colormap = cv2.COLORMAP_PLASMA
199             elif colormap == 'turbo': # >=4.1.2
200                 colormap = cv2.COLORMAP_TURBO
201             #
202             array = (255.0*tensor.squeeze().numpy()/max_value).clip(0, 255).astype(np.uint8)
203             colored_array = cv2.applyColorMap(array, colormap)
204             array = cv2.cvtColor(colored_array, color_cvt).astype(np.float32) / 255.0
205         except ImportError:
206             if tensor.ndimension() == 2:
207                 tensor.unsqueeze_(2)
208             #
209             array = (tensor.expand(tensor.size(0), tensor.size(1), 3).numpy()/max_value).clip(0,1)
210     elif tensor.ndimension() == 3:
211         assert(tensor.size(0) == 3)
212         array = 0.5 + tensor.numpy().transpose(1, 2, 0)*0.5
213     #
214     if input_blend is not None:
215         array = image_utils.chroma_blend(input_blend, array)
216     #
217     return array
220 def tensor2img(tensor, max_value=63535):
221     array = (63535*tensor.numpy()/max_value).clip(0, 63535).astype(np.uint16)
222     if tensor.ndimension() == 3:
223         assert (array.size(0) == 3)
224         array = array.transpose(1, 2, 0)
225     return array
228 ##################################################################
229 def inverse_warp_flow(img, flow, padding_mode='zeros'):
230     """
231     Inverse warp a source image to the target image plane.
233     Args:
234         img: the source image (where to sample pixels) -- [B, 3, H, W]
235         flow: flow to be used for warping
236     Returns:
237         Source image warped to the target image plane
238     """
240     #check_sizes(img, 'img', 'B3HW')
241     check_sizes(flow, 'flow', 'B2HW')
243     b,c,h,w = img.size()
244     h2 = (h-1.0)/2.0
245     w2 = (w-1.0)/2.0
247     pixel_coords = img_set_id_grid_(img)
249     src_pixel_coords = pixel_coords + flow
251     x_coords = src_pixel_coords[:, 0]
252     x_coords = (x_coords - w2) / w2
254     y_coords = src_pixel_coords[:, 1]
255     y_coords = (y_coords - h2) / h2
257     src_pixel_coords = torch.stack((x_coords, y_coords), dim=3)
258     projected_img = torch.nn.functional.grid_sample(img, src_pixel_coords, \
259                           mode='bilinear', padding_mode=padding_mode)
261     return projected_img
264 def img_set_id_grid_(img):
265     b, c, h, w = img.size()
266     x_range = torch.Tensor(torch.arange(0, w).view(1, 1, w).expand(1,h,w)).type_as(img)  # [1, H, W]
267     y_range = torch.Tensor(torch.arange(0, h).view(1, h, 1).expand(1,h,w)).type_as(img)  # [1, H, W]
268     pixel_coords = torch.stack((x_range, y_range), dim=1).float()  # [1, 2, H, W]
269     return pixel_coords
272 def crop_like(input, target):
273     if target is None or (input.size()[2:] == target.size()[2:]):
274         return input
275     else:
276         return input[:, :, :target.size(2), :target.size(3)]
279 def crop_alike(input, target):
280     global crop_alike_warning_done
281     if target is None or (input.size() == target.size()):
282         return input, target
284     warnings.warning('=> tensor dimension mismatch. input:{}, target:{}. cropping'.ormat(input.size(),target.size()))
286     min_ch = min(input.size(1), target.size(1))
287     min_h = min(input.size(2), target.size(2))
288     min_w = min(input.size(3), target.size(3))
289     h_offset_i = h_offset_t = w_offset_i = w_offset_t = 0
290     if input.size(2) > target.size(2):
291         h_offset_i = (input.size(2) - target.size(2))//2
292     else:
293         h_offset_t = (target.size(2) - input.size(2))//2
295     if input.size(3) > target.size(3):
296         w_offset_i = (input.size(3) - target.size(3))//2
297     else:
298         w_offset_t = (target.size(3) - input.size(3))//2
300     input = input[:, :min_ch, h_offset_i:(h_offset_i+min_h), w_offset_i:(w_offset_i+min_w)]
301     target = target[:, :min_ch, h_offset_t:(h_offset_t+min_h), w_offset_t:(w_offset_t+min_w)]
303     return input, target
306 def align_channels_(x,y):
307     chan_x = x.size(1)
308     chan_y = y.size(1)
309     if chan_x != chan_y:
310         chan_min = min(chan_x, chan_y)
311         x = x[:,:chan_min,...]
312         if len(x.size()) < 4:
313             x = torch.unsqueeze(x,dim=1)
314         y = y[:,:chan_min,...]
315         if len(y.size()) < 4:
316             y = torch.unsqueeze(y,dim=1)
317     return x, y
320 def debug_dump_tensor(tensor, image_name, adjust_range=True):
321     img = tensor2img(tensor, adjust_range=adjust_range)
322     scipy.misc.imsave(image_name, img)