From d162e610d91fcdb7f2750182db83f4d46af3c7b8 Mon Sep 17 00:00:00 2001
From: Manu Mathew
Date: Fri, 27 Mar 2020 18:23:04 +0530
Subject: [PATCH] quantization aware training - bugfix for merged weights
 becoming 0 (typically due to one bn weight becoming 0)

---
 .../vision/losses/segmentation_loss.py        |   2 +-
 .../transforms/image_transforms_xv12.py       | 293 ++++++++++++++++--
 .../xnn/quantize/quant_train_utils.py         |  41 ++-
 3 files changed, 295 insertions(+), 41 deletions(-)

diff --git a/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py b/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py
index 9e8c76f..90b715c 100755
--- a/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py
+++ b/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from .loss_utils import *
 
-__all__ = ['segmentation_loss', 'segmentation_metrics']
+__all__ = ['segmentation_loss', 'segmentation_metrics', 'SegmentationMetricsCalc']
 
 
 def cross_entropy2d(input, target, weight=None, ignore_index=None, size_average=True):
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
index 0c50cc2..b058ac7 100644
--- a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
+++ b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
@@ -81,37 +81,253 @@ class RGBtoYV12(object):
         self.is_flow = is_flow
         self.keep_rgb = keep_rgb
+
+    def debug_print(self, img=None, enable=False):
+        if enable:
+            h = img.shape[0] * 2 // 3
+            w = img.shape[1]
+
+            print("-" * 32, "RGBtoYV12")
+            print("Y")
+
+            print(img[0:5, 0:5])
+
+            print("V Odd Lines")
+            print(img[h:h + 5, 0:5])
+
+            print("V Even Lines")
+            print(img[h:h + 5, w // 2:w // 2 + 5])
+
+            print("U Odd Lines")
+            print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
+            print("U Even Lines")
+            print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+
+            print("-" * 32)
+
     def __call__(self, images, target):
         for img_idx in range(len(images)):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
             if not is_flow_img:
                 images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12)
-                h = images[img_idx].shape[0] * 2 // 3
-                w = images[img_idx].shape[1]
-                debug_print = False
-                if debug_print:
-                    print("-" * 32, "RGBtoYV12")
-                    print("Y")
-                    img = images[img_idx]
-                    print(img[0:5, 0:5])
+                self.debug_print(img = images[img_idx], enable=False)
+        return images, target
 
-                    print("V Odd Lines")
-                    print(img[h:h + 5, 0:5])
 
-                    print("V Even Lines")
-                    print(img[h:h + 5, w // 2:w // 2 + 5])
+#convert from YV12 to YUV444 with optional resize
+def yv12_to_yuv444(img=None, out_size=None):
+    in_w = img.shape[1]
+    in_h = (img.shape[0] * 2) // 3
+
+    y_h = in_h
+    in_uv_h = in_h // 4
+
+    Y = img[0:in_h, 0:in_w]
+
+    V = img[y_h:y_h + in_uv_h, 0:in_w]
+    V = V.reshape(V.shape[0]*2, -1)
+
+    U = img[y_h + in_uv_h:y_h + 2 * in_uv_h, 0:in_w]
+    U = U.reshape(U.shape[0] * 2, -1)
+
+    #if out_size is None, use the input size as the output size - in that case this function becomes a plain YV12 to YUV444 conversion
+    if out_size is None:
+        out_size = (in_h, in_w)
+
+    out_h, out_w = out_size
+
+    Y = cv2.resize(Y, (out_w, out_h), interpolation=cv2.INTER_NEAREST)
+    U = cv2.resize(U, (out_w, out_h), interpolation=cv2.INTER_NEAREST)
+    V = cv2.resize(V, (out_w, out_h), interpolation=cv2.INTER_NEAREST)
+
+    yuv444 = np.zeros((out_h, out_w, 3), dtype=np.uint8)
+    yuv444[:, :, 0] = Y
+    yuv444[:, :, 1] = U
+    yuv444[:, :, 2] = V
+    return yuv444
+
+def get_w_b_yuv_to_rgb(device=None, rounding=True):
+    offset = OFFSET if rounding else 0
+
+    uv_mean = 128
+    w_yuv_to_rgb = torch.tensor([ITUR_BT_601_CY, 0, ITUR_BT_601_CVR,
+                                 ITUR_BT_601_CY, ITUR_BT_601_CUG, ITUR_BT_601_CVG,
+                                 ITUR_BT_601_CY, ITUR_BT_601_CUB, 0], dtype=torch.float, device=device).reshape(3,3)
+
+    #print(w_yuv_to_rgb)
+    w_yuv_to_rgb = w_yuv_to_rgb/(1<<ITUR_BT_601_SHIFT)
+
+    b_yuv_to_rgb = torch.tensor([offset-ITUR_BT_601_CVR*uv_mean-16*ITUR_BT_601_CY,
+                                 offset-ITUR_BT_601_CVG*uv_mean-ITUR_BT_601_CUG*uv_mean-16*ITUR_BT_601_CY,
+                                 offset-ITUR_BT_601_CUB*uv_mean-16*ITUR_BT_601_CY], dtype=torch.float, device=device).reshape(3,1)
+    b_yuv_to_rgb = b_yuv_to_rgb/(1<<ITUR_BT_601_SHIFT)
+    return [w_yuv_to_rgb, b_yuv_to_rgb]
+
+#r = ((y-16) * ITUR_BT_601_CY + OFFSET + ITUR_BT_601_CVR * (v-uv_mean) ) >> ITUR_BT_601_SHIFT
+#g = ((y-16) * ITUR_BT_601_CY + OFFSET + ITUR_BT_601_CVG * (v-uv_mean) + ITUR_BT_601_CUG * (u-uv_mean) ) >> ITUR_BT_601_SHIFT
+#b = ((y-16) * ITUR_BT_601_CY + OFFSET + ITUR_BT_601_CUB * (u-uv_mean) ) >> ITUR_BT_601_SHIFT
+
+#r = (y * ITUR_BT_601_CY + ITUR_BT_601_CVR * v + (OFFSET-ITUR_BT_601_CVR*uv_mean -16*ITUR_BT_601_CY) ) >> ITUR_BT_601_SHIFT
+#g = (y * ITUR_BT_601_CY + ITUR_BT_601_CVG * v + ITUR_BT_601_CUG * u + (OFFSET- ITUR_BT_601_CVG*uv_mean-ITUR_BT_601_CUG*uv_mean-16*ITUR_BT_601_CY)) >> ITUR_BT_601_SHIFT
+#b = (y * ITUR_BT_601_CY + ITUR_BT_601_CUB * u + (OFFSET-ITUR_BT_601_CUB*uv_mean-16*ITUR_BT_601_CY) ) >> ITUR_BT_601_SHIFT
+
+# w_yuv_to_rgb = np.array([ITUR_BT_601_CY, 0, ITUR_BT_601_CVR,
+#                          ITUR_BT_601_CY, ITUR_BT_601_CUG, ITUR_BT_601_CVG,
+#                          ITUR_BT_601_CY, ITUR_BT_601_CUB, 0]).reshape(3,3)
+#
+# b_yuv_to_rgb = np.array([OFFSET-ITUR_BT_601_CVR*uv_mean-16*ITUR_BT_601_CY,
+#                          OFFSET-ITUR_BT_601_CVG*uv_mean-ITUR_BT_601_CUG*uv_mean-16*ITUR_BT_601_CY,
+#                          OFFSET-ITUR_BT_601_CUB*uv_mean-16*ITUR_BT_601_CY]).reshape(3,1)
+
+#ref https://github.com/opencv/opencv/blob/8c0b0714e76efef4a8ca2a7c410c60e55c5e9829/modules/imgproc/src/color_yuv.simd.hpp#L1075
+ITUR_BT_601_CY = 1220542
+ITUR_BT_601_CUB = 2116026
+ITUR_BT_601_CUG = -409993
+ITUR_BT_601_CVG = -852492
+ITUR_BT_601_CVR = 1673527
+ITUR_BT_601_SHIFT = 20
+OFFSET = (1 << (ITUR_BT_601_SHIFT - 1))
+
+
+def report_stat(diff=None):
+    unique, counts = np.unique(diff, return_counts=True)
+    result = dict(zip(unique, counts))
+    print(result)
+    counts_list = np.zeros(max(unique) + 1, dtype=np.int)
+    for (diff, count) in zip(unique, counts):
+        counts_list[diff] = count
+
+    str_to_print = ','.join('%d' % x for x in counts_list)
+    print(str_to_print)
+    counts_list = (100.00 * counts_list) / sum(counts_list)
+    str_to_print = ','.join('%8.5f' % x for x in counts_list)
+    print(str_to_print)
+
+def compare_diff(tensor_ref=None, tensor=None, exact_comp=False, roi_h=None, roi_w=None, ch_axis_idx=2, auto_scale=False):
+    roi_h = [0,tensor.shape[1]] if roi_h is None else roi_h
+    roi_w = [0,tensor.shape[2]] if roi_w is None else roi_w
+
+    if ch_axis_idx == 0:
+        # swap ch index as subsequent code assumes ch_idx to be 2
+        tensor = np.moveaxis(tensor, 0,-1)
+        tensor_ref = np.moveaxis(tensor_ref, 0, -1)
+
+    n_ch = tensor.shape[2]
+    if ch_axis_idx != 0 and ch_axis_idx != 2:
+        exit("wrong ch index in compare_diff()")
+
+    # crop
+    tensor = tensor[roi_h[0]:roi_h[1], roi_w[0]:roi_w[1], :]
+    tensor_ref = tensor_ref[roi_h[0]:roi_h[1], roi_w[0]:roi_w[1], :]
+
+    #needed for float arrays. Convert float arrays to int with range [-128, 128]
+    if auto_scale:
+        max_val = np.amax(np.abs(tensor_ref))
+        scale = 128.0/max_val
+        tensor_ref = (tensor_ref * scale).astype(np.int)
+        tensor = (tensor * scale).astype(np.int)
+
+    if exact_comp:
+        for ch in range(n_ch):
+            print("ch: ", ch, " matching: ", np.array_equal(tensor_ref[:, :, ch], tensor[:, :, ch]))
+            indices = np.where(tensor_ref[:, :, ch] != tensor[:, :, ch])
+            for (idx_h, idx_w) in zip(indices[0],indices[1]):
+                print(tensor_ref[idx_h, idx_w, ch], " : ", tensor[idx_h, idx_w, ch])
+    else: #if clip is not used it is not expected to match with ref so just find how many are differing
+        print(" Global stats:")
+        diff = np.abs(tensor_ref[:, :, :] - tensor[:, :, :])
+        report_stat(diff)
+        for ch in range(n_ch):
+            print(" ========= ch:", ch)
+            diff = np.abs(tensor_ref[:, :, ch] - tensor[:, :, ch])
+            report_stat(diff)
+
+
+def image_padding(img=None, pad_vals=[0,0,0]):
+    img_padded = np.empty((img.shape[0] + 2, img.shape[1] + 2, img.shape[2]), dtype=img.dtype)
+    img_padded[:, :, 0] = pad_vals[0]
+    img_padded[:, :, 1] = pad_vals[1]
+    img_padded[:, :, 2] = pad_vals[2]
+    img_padded[1:-1, 1:-1, :] = img[:, :, :]
+    return img_padded
+
+
+def opencv_yuv_to_rgb(yuv444=None, uv_mean=0, clip=False, rounding=True, matrix_based_implementation=False,
+                      op_type=np.int):
+
+    if matrix_based_implementation:
+        implementation_with_loops = False
+        rgb = np.empty_like(yuv444, dtype=op_type)
+        [w_yuv_to_rgb, b_yuv_to_rgb] = get_w_b_yuv_to_rgb(device='cpu', rounding=rounding)
+
+        if implementation_with_loops:
+            for y in range(yuv444.shape[0]):
+                for x in range(yuv444.shape[1]):
+                    yuv444_sample = yuv444[y,x,:].reshape(3,1).astype(np.float)
+                    temp = w_yuv_to_rgb @ yuv444_sample + b_yuv_to_rgb.reshape(3,1)
+                    rgb[y,x,:] = np.squeeze(temp)
+        else:
+            r = np.dot(yuv444, w_yuv_to_rgb[0]) + b_yuv_to_rgb[0]
+            g = np.dot(yuv444, w_yuv_to_rgb[1]) + b_yuv_to_rgb[1]
+            b = np.dot(yuv444, w_yuv_to_rgb[2]) + b_yuv_to_rgb[2]
+            rgb[:, :, 0] = r
+            rgb[:, :, 1] = g
+            rgb[:, :, 2] = b
+        #compare_diff(tensor_ref=rgb_matrix_based, tensor=rgb, exact_comp=True, ch_axis_idx=2)
+    else:
+        y = yuv444[:, :, 0].astype(dtype=np.int) - 16
+        u = yuv444[:, :, 1].astype(dtype=np.int) - uv_mean
+        v = yuv444[:, :, 2].astype(dtype=np.int) - uv_mean
+        ruv = OFFSET + ITUR_BT_601_CVR * v
+        guv = OFFSET + ITUR_BT_601_CVG * v + ITUR_BT_601_CUG * u
+        buv = OFFSET + ITUR_BT_601_CUB * u
+
+        y00 = np.maximum(y, 0) * ITUR_BT_601_CY if clip else y * ITUR_BT_601_CY
+
+        r = (y00 + ruv) >> ITUR_BT_601_SHIFT
+        g = (y00 + guv) >> ITUR_BT_601_SHIFT
+        b = (y00 + buv) >> ITUR_BT_601_SHIFT
+
+        if clip:
+            r = np.clip(r, 0, 255)
+            g = np.clip(g, 0, 255)
+            b = np.clip(b, 0, 255)
+
+        rgb = np.empty(yuv444.shape, dtype=np.int)
+        rgb[:, :, 0] = r
+        rgb[:, :, 1] = g
+        rgb[:, :, 2] = b
+
+    #compare_diff(rgb_ref=rgb, rgb=rgb_matrix_based, clip=False)
+    return rgb
 
-                    print("U Odd Lines")
-                    print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
 
-                    print("U Even Lines")
-                    print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+class YV12toRGB(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
 
-                    print("-" * 32)
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12)
         return images, target
-
-class YV12toRGB(object):
+#similar to YV12toRGB but without clip
+class YV12toRGBWithoutClip(object):
     def __init__(self, is_flow=None, keep_rgb=False):
         self.is_flow = is_flow
         self.keep_rgb = keep_rgb
@@ -120,18 +336,37 @@ class YV12toRGB(object):
         for img_idx in range(len(images)):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
             if not is_flow_img:
-                images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12)
-                debug_print = False
-                if debug_print:
-                    print("shape after YV12toRGB ", images[img_idx].shape)
-                    print("-" * 32, "YV12toRGB")
-                    img = images[img_idx]
-                    print(img[0:5, 0:5, 0])
-                    print(img[0:5, 0:5, 1])
-                    print(img[0:5, 0:5, 2])
-                    print("-" * 32)
+                yuv444 = yv12_to_yuv444(images[img_idx])
+                images[img_idx] = opencv_yuv_to_rgb(yuv444=yuv444, uv_mean=128, matrix_based_implementation=False)
+        return images, target
+
+
+# Padding around image boundaries
+class ImagePadding(object):
+    def __init__(self, is_flow=None, keep_rgb=False, pad_vals=[0,0,0]):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+        self.pad_vals = pad_vals
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = image_padding(img=images[img_idx], pad_vals=self.pad_vals)
         return images, target
+
+#YV12 to YUV444
+class YV12toYUV444(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = yv12_to_yuv444(images[img_idx])
+        return images, target
 
 class YV12toNV12(object):
     def __init__(self, is_flow=None, keep_rgb=False):
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
index c91a6aa..553f2fa 100644
--- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
+++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
@@ -235,7 +235,9 @@ class QuantTrainPAct2(layers.PAct2):
                 merged_bias = (conv_bias - bn.running_mean) * merged_scale + bn_bias
                 merged_weight = conv.weight * merged_scale.view(-1, 1, 1, 1)
                 #
-                merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale.sign()
+                merged_scale_sign = merged_scale.sign()
+                merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0)  # make the 0s in sign into 1s, so a zero bn weight cannot zero out the merged weight
+                merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
                 merged_scale_inv = 1.0 / merged_scale_eps
                 #
             elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
@@ -282,13 +284,14 @@ class QuantTrainPAct2(layers.PAct2):
                 else:
                     clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
                 #
-                width_min, width_max, bias_width_min, bias_width_max = self.get_widths_w()
+                width_min, width_max = self.get_widths_w()
                 #merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
                 merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
             #
             if (self.quantize_bias):
-                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_bias_clips_scale_w(merged_bias)
+                bias_width_min, bias_width_max = self.get_widths_bias()
+                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
                #merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
            #
 
@@ -318,9 +321,9 @@ class QuantTrainPAct2(layers.PAct2):
         return conv, merged_weight, merged_bias
 
 
-    def get_clips_w(self, weight):
+    def get_clips_w(self, tensor):
         # find the clip values
-        w_min, w_max = utils.extrema_fast(weight.data, percentile_range_shrink=self.range_shrink_weights)
+        w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
         clip_max = torch.clamp(clip_max, min=self.eps)
         # in range learning mode + training - this power2 is taken care of in the quantize function
@@ -329,20 +332,30 @@
         clip_min2 = -clip_max2
         return (clip_min2, clip_max2)
 
 
+    # bias uses the same kind of clips
+    get_clips_bias = get_clips_w
+
     def get_clips_scale_w(self, weight):
         # convert to scale
-        clip_min, clip_max = self.get_clips_w(weight=weight)
-        width_min, width_max, _, _ = self.get_widths_w()
+        clip_min, clip_max = self.get_clips_w(weight)
+        width_min, width_max = self.get_widths_w()
         scale2 = (width_max / clip_max)
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
 
+    # in reality, bias quantization will also depend on the activation scale
     # this is not perfect - just a quick and dirty quantization for bias
-    def get_bias_clips_scale_w(self, bias):
-        return self.get_clips_scale_w(bias)
+    def get_clips_scale_bias(self, bias):
+        # convert to scale
+        clip_min, clip_max = self.get_clips_bias(bias)
+        width_min, width_max = self.get_widths_bias()
+        scale2 = (width_max / clip_max)
+        scale2 = torch.clamp(scale2, min=self.eps)
+        scale_inv2 = scale2.pow(-1.0)
+        return (clip_min, clip_max, scale2, scale_inv2)
 
 
     def get_widths_w(self):
@@ -350,11 +363,17 @@
         bw = (self.bitwidth_activations - 1)
         width_max = np.power(2.0, bw)
         width_min = -width_max
+        # return
+        return (width_min, width_max)
+
+
+    def get_widths_bias(self):
         # bias
-        bias_width_max = np.power(2.0, 2*self.bitwidth_activations-1)
+        bitwidth_bias = (2*self.bitwidth_activations)
+        bias_width_max = np.power(2.0, bitwidth_bias-1)
         bias_width_min = -bias_width_max
         # return
-        return (width_min, width_max, bias_width_min, bias_width_max)
+        return (bias_width_min, bias_width_max)
 
 
     # activation utility functions
-- 
2.39.2
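
Illustration (not part of the patch): a minimal usage sketch of the YV12 transform classes added above. The import path assumes the repository's modules/ directory is on PYTHONPATH and an era-appropriate numpy/OpenCV install (the module still uses np.int); the dummy 64x64 frame and the (images, target) list/None calling convention follow the __call__ signatures in image_transforms_xv12.py.

    import numpy as np
    from pytorch_jacinto_ai.vision.transforms.image_transforms_xv12 import (
        RGBtoYV12, YV12toRGBWithoutClip, ImagePadding)

    rgb = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)   # dummy RGB frame
    images, target = [rgb], None

    images, target = RGBtoYV12(is_flow=None)(images, target)             # (96, 64) planar YV12
    images, target = YV12toRGBWithoutClip(is_flow=None)(images, target)  # back to RGB via yv12_to_yuv444 + opencv_yuv_to_rgb, no clipping
    images, target = ImagePadding(is_flow=None, pad_vals=[0, 0, 0])(images, target)
    print(images[0].shape)   # (66, 66, 3) after the 1-pixel border padding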
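
Illustration (not part of the patch): the weight/bias pair built by get_w_b_yuv_to_rgb() is the same ITU-R BT.601 integer arithmetic used in opencv_yuv_to_rgb(), with the coefficients pre-divided by 2**ITUR_BT_601_SHIFT and the -16 / -uv_mean offsets folded into a bias term. A small self-contained numpy check of the R channel on one arbitrary pixel (constants copied from the patch):

    import numpy as np

    ITUR_BT_601_CY, ITUR_BT_601_CVR = 1220542, 1673527
    ITUR_BT_601_SHIFT = 20
    OFFSET = 1 << (ITUR_BT_601_SHIFT - 1)
    uv_mean = 128

    y, u, v = 90, 120, 200   # arbitrary YUV sample

    # integer path, as in opencv_yuv_to_rgb() with clip=False
    r_int = ((y - 16) * ITUR_BT_601_CY + OFFSET + ITUR_BT_601_CVR * (v - uv_mean)) >> ITUR_BT_601_SHIFT

    # matrix path: weights = coefficients / 2**SHIFT, bias = (OFFSET - CVR*uv_mean - 16*CY) / 2**SHIFT
    w_r = np.array([ITUR_BT_601_CY, 0, ITUR_BT_601_CVR]) / (1 << ITUR_BT_601_SHIFT)
    b_r = (OFFSET - ITUR_BT_601_CVR * uv_mean - 16 * ITUR_BT_601_CY) / (1 << ITUR_BT_601_SHIFT)
    r_mat = np.dot(np.array([y, u, v]), w_r) + b_r

    print(r_int, r_mat)   # 201 vs ~201.5 - the two differ only by the >> truncation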
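
Illustration (not part of the patch): the core of the bugfix in quant_train_utils.py is that Tensor.sign() returns 0 for an exactly-zero input, so a batchnorm weight that collapses to 0 used to zero out that channel's merged scale and, with it, the merged weights. A standalone before/after sketch with made-up values:

    import torch

    # per-channel merged scale in which one bn weight has collapsed to exactly 0
    merged_scale = torch.tensor([0.25, 0.0, -0.5])
    eps = 1e-5

    # old behaviour: sign() of 0 is 0, so the clamped magnitude is multiplied by 0 again
    old = merged_scale.abs().clamp(min=eps) * merged_scale.sign()   # channel 1 stays 0

    # patched behaviour: turn the 0s in the sign into +1 before multiplying
    sign = merged_scale.sign()
    sign = sign + (sign == 0)                                       # 0 -> 1, +/-1 unchanged
    new = merged_scale.abs().clamp(min=eps) * sign                  # channel 1 becomes +eps
    print(old, new)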