From: Manu Mathew
Date: Thu, 12 Mar 2020 10:45:36 +0000 (+0530)
Subject: torch.nn.ReLU is the recommended activation module. removed the custom defined module...
X-Git-Url: https://git.ti.com/gitweb?p=jacinto-ai%2Fpytorch-jacinto-ai-devkit.git;a=commitdiff_plain;h=ab84ec9f9e27a32f45a2e6372f32ba3fe9b1a799

torch.nn.ReLU is the recommended activation module. Removed the custom defined
module called ReLUN - if a fixed-range activation module is needed,
torch.nn.Hardtanh can be used instead.
---
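For reference, a minimal runnable sketch of the replacement pattern described
above (the ReLU1 class below mirrors the one added in this patch; the usage
lines are illustration only, not part of the commit):

    import torch

    # Fixed-range activation built on torch.nn.Hardtanh, mirroring the
    # ReLU1 class this commit adds (clips to [0, 1] by default).
    class ReLU1(torch.nn.Hardtanh):
        def __init__(self, min_val=0., max_val=1., inplace=False):
            super().__init__(min_val=min_val, max_val=max_val, inplace=inplace)

    x = torch.linspace(-2.0, 2.0, steps=5)      # [-2, -1, 0, 1, 2]
    print(ReLU1()(x))                           # tensor([0., 0., 0., 1., 1.])
    print(torch.nn.Hardtanh(0.0, 6.0)(x))       # ReLU6-like: [0., 0., 0., 1., 2.]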
diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
index 605cef2..e9bf732 100644
--- a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
+++ b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
@@ -52,7 +52,7 @@ def get_config():
     args.model_config.freeze_decoder = False        # do not update decoder weights
     args.model_config.multi_task_type = 'learned'   # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
     args.model_config.target_input_ratio = 1        # Keep target size same as input size
-
+    args.model_config.input_nv12 = False
     args.model = None                               # the model itself can be given from outside
     args.model_name = 'deeplabv2lite_mobilenetv2'   # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'   # dataset type
@@ -169,6 +169,7 @@ def get_config():
     args.tensorboard_enable = True                  # en/disable of TB writing
     args.print_train_class_iou = False
     args.print_val_class_iou = False
+    args.freeze_layers = None
 
     return args
 
@@ -310,19 +311,24 @@ def main(args):
     #################################################
     pretrained_data = None
     model_surgery_quantize = False
+    pretrained_data = None
     if args.pretrained and args.pretrained != "None":
-        if isinstance(args.pretrained, dict):
-            pretrained_data = args.pretrained
-        else:
-            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
-                pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        pretrained_data = []
+        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
+        for p in pretrained_files:
+            if isinstance(p, dict):
+                p_data = p
             else:
-                pretrained_file = args.pretrained
+                if p.startswith('http://') or p.startswith('https://'):
+                    p_file = vision.datasets.utils.download_url(p, './data/downloads')
+                else:
+                    p_file = p
+                #
+                print(f'=> loading pretrained weights file: {p}')
+                p_data = torch.load(p_file)
             #
-            print(f'=> using pre-trained weights from: {args.pretrained}')
-            pretrained_data = torch.load(pretrained_file)
-        #
-        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+            pretrained_data.append(p_data)
+            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
     #
 
     #################################################
@@ -361,7 +367,9 @@ def main(args):
 
     # load pretrained model
     if pretrained_data is not None:
-        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+        for (p_data,p_file) in zip(pretrained_data, pretrained_files):
+            print("=> using pretrained weights from: {}".format(p_file))
+            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
     #
 
     #################################################
@@ -594,6 +602,18 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
     if args.freeze_bn:
         xnn.utils.freeze_bn(model)
     #
+
+    #freeze layers
+    if args.freeze_layers is not None:
+        # 'freeze_layer_name' can be a substring of 'name', i.e. 'name' need not be exactly the same as 'freeze_layer_name'
+        # e.g. if freeze_layer_name = 'encoder.0', then all layers such as 'encoder.0.0', 'encoder.0.1' will be frozen
+        for freeze_layer_name in args.freeze_layers:
+            for name, module in model.named_modules():
+                if freeze_layer_name in name:
+                    xnn.utils.print_once("Freezing the module : {}".format(name))
+                    module.eval()
+                    for param in module.parameters():
+                        param.requires_grad = False
 
     ##########################
     for task_dx, task_losses in enumerate(args.loss_modules):
@@ -621,7 +641,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
 
         lr = scheduler.get_lr()[0]
 
-        input_list = [img.cuda() for img in inputs]
+        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
         target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
         target_sizes = [tgt.shape for tgt in target_list]
         batch_size_cur = target_sizes[0][0]
@@ -632,7 +652,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
 
         # upsample output to target resolution
-        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+        if args.upsample_mode is not None:
+            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
 
         if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
             args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
@@ -774,7 +795,7 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
     ##########################
     for iter, (inputs, targets) in enumerate(val_loader):
         data_time.update(time.time() - end_time)
-        input_list = [j.cuda() for j in inputs]
+        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
         target_list = [j.cuda(non_blocking=True) for j in targets]
         target_sizes = [tgt.shape for tgt in target_list]
         batch_size_cur = target_sizes[0][0]
@@ -782,8 +803,10 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
 
         # compute output
         task_outputs = model(input_list)
-        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
-        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+
+        task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
+        if args.upsample_mode is not None:
+            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
 
         if args.print_val_class_iou:
             metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
@@ -864,13 +887,20 @@ def get_model_orig(model):
 
 def create_rand_inputs(args, is_cuda):
     dummy_input = []
-    for i_ch in args.model_config.input_channels:
-        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
-        x = x.cuda() if is_cuda else x
-        dummy_input.append(x)
-    #
-    return dummy_input
+    if not args.model_config.input_nv12:
+        for i_ch in args.model_config.input_channels:
+            x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
+            x = x.cuda() if is_cuda else x
+            dummy_input.append(x)
+    else: #nv12
+        for i_ch in args.model_config.input_channels:
+            y = torch.rand((1, 1, args.img_resize[0], args.img_resize[1]))
+            uv = torch.rand((1, 1, args.img_resize[0]//2, args.img_resize[1]))
+            y = y.cuda() if is_cuda else y
+            uv = uv.cuda() if is_cuda else uv
+            dummy_input.append([y,uv])
+    return dummy_input
 
 def count_flops(args, model):
     is_cuda = next(model.parameters()).is_cuda
@@ -922,15 +952,23 @@ def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writ
         write_prob = np.random.random()
         if (write_prob > write_freq):
             return
-
-    batch_size = input_images[0].shape[0]
+    if args.model_config.input_nv12:
+        batch_size = input_images[0][0].shape[0]
+    else:
+        batch_size = input_images[0].shape[0]
     b_index = random.randint(0, batch_size - 1)
 
     input_image = None
     for img_idx, img in enumerate(input_images):
-        input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
-        # convert back to original input range (0-255)
-        input_image = input_image / args.image_scale + args.image_mean
+        if args.model_config.input_nv12:
+            #convert NV12 to BGR for tensorboard
+            input_image = vision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y=input_images[img_idx][0][b_index], UV=input_images[img_idx][1][b_index],
+                image_scale=args.image_scale, image_mean=args.image_mean)
+        else:
+            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
+            # convert back to original input range (0-255)
+            input_image = input_image / args.image_scale + args.image_mean
+
         if args.is_flow and args.is_flow[0][img_idx]:
             #input corresponding to flow is assumed to have been generated by adding 128
             flow = input_image - 128
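Illustration only (not part of the patch): a self-contained sketch of the
substring-based freezing that the new args.freeze_layers option performs.
The two-layer model and the layer names here are made up.

    import torch

    # Hypothetical model; named_modules() yields names such as 'encoder.0'.
    model = torch.nn.Sequential()
    model.add_module('encoder', torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)))
    model.add_module('decoder', torch.nn.Sequential(torch.nn.Conv2d(8, 3, 3)))

    freeze_layers = ['encoder.0']  # substring match, as in the patch
    for freeze_layer_name in freeze_layers:
        for name, module in model.named_modules():
            if freeze_layer_name in name:
                module.eval()  # eval() also freezes BN running stats if a BN layer matches
                for param in module.parameters():
                    param.requires_grad = False

    print([n for n, p in model.named_parameters() if not p.requires_grad])
    # ['encoder.0.weight', 'encoder.0.bias']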
diff --git a/modules/pytorch_jacinto_ai/vision/losses/loss_utils.py b/modules/pytorch_jacinto_ai/vision/losses/loss_utils.py
index 21a04bc..2df9485 100644
--- a/modules/pytorch_jacinto_ai/vision/losses/loss_utils.py
+++ b/modules/pytorch_jacinto_ai/vision/losses/loss_utils.py
@@ -10,6 +10,8 @@ def l2_norm(x,y=None):
     diff = (x-y) if y is not None else x
     return torch.norm(diff,p=2,dim=1,keepdim=True)
 
+def l2_norm_self(x,y=None):
+    return torch.norm(x,p=2,dim=1,keepdim=True)
 
 def l1_norm(x,y=None):
     if y is not None:
diff --git a/modules/pytorch_jacinto_ai/vision/losses/norm_loss.py b/modules/pytorch_jacinto_ai/vision/losses/norm_loss.py
index 10fb06e..0c317dd 100644
--- a/modules/pytorch_jacinto_ai/vision/losses/norm_loss.py
+++ b/modules/pytorch_jacinto_ai/vision/losses/norm_loss.py
@@ -15,6 +15,13 @@ class L2NormDiff(BasicNormLossModule):
 
 supervised_l2_loss = l2_norm_loss = L2NormDiff
 
+#take absolute norm of tensor as loss instead of diff
+class L2NormSelf(BasicNormLossModule):
+    def __init__(self, sparse=False, error_fn=l2_norm_self, error_name='L2NormSelf'):
+        super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)
+supervised_l2_loss_self = l2_norm_loss_self = L2NormSelf
+
+
 ########################################
 class SmoothL1Diff(BasicElementwiseLossModule):
     def __init__(self, sparse=False, error_fn=smooth_l1_loss, error_name='SmoothL1Diff'):
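Illustration only: the new L2NormSelf penalizes the magnitude of the prediction
itself rather than its difference from a target (presumably useful where the
output should be driven toward zero). A standalone sketch of the error function
and its shapes:

    import torch

    def l2_norm_self(x, y=None):
        # channel-wise L2 magnitude of x; y is accepted but ignored,
        # matching the (x, y) signature of the other error functions
        return torch.norm(x, p=2, dim=1, keepdim=True)

    x = torch.randn(2, 3, 4, 4)   # N,C,H,W prediction
    err = l2_norm_self(x)         # shape (2, 1, 4, 4)
    loss = err.mean()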
diff --git a/modules/pytorch_jacinto_ai/vision/models/multi_input_net.py b/modules/pytorch_jacinto_ai/vision/models/multi_input_net.py
index 4e76058..b411c0d 100644
--- a/modules/pytorch_jacinto_ai/vision/models/multi_input_net.py
+++ b/modules/pytorch_jacinto_ai/vision/models/multi_input_net.py
@@ -7,11 +7,11 @@ from .resnet import resnet50_with_model_config
 try: from .mobilenetv2_ericsun_internal import *
 except: pass
 
-try: from .mobilenetv2_gws_internal import *
+try: from .mobilenetv2_internal import *
 except: pass
 
 __all__ = ['MultiInputNet', 'mobilenet_v2_tv_mi4', 'mobilenet_v2_tv_gws_mi4', 'mobilenet_v2_ericsun_mi4',
-           'MobileNetV2TVMI4', 'ResNet50MI4']
+           'MobileNetV2TVMI4', 'MobileNetV2TVNV12MI4', 'ResNet50MI4']
 
 
 ###################################################
@@ -212,6 +212,16 @@ class MobileNetV2EricsunMI4(MultiInputNet):
 
 mobilenet_v2_ericsun_mi4 = MobileNetV2EricsunMI4
 
+
+# these are the real multi input blocks
+class MobileNetV2TVNV12MI4(MultiInputNet):
+    def __init__(self, model_config):
+        model_config.num_input_blocks = 4
+        model_config.fuse_channels = 24
+        super().__init__(MobileNetV2TVNV12, model_config)
+#
+mobilenet_v2_tv_nv12_mi4 = MobileNetV2TVNV12MI4
+
 # these are the real multi input blocks
 class MobileNetV2TVGWSMI4(MultiInputNet):
     def __init__(self, model_config):
         model_config.num_input_blocks = 4
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
index a7da590..dda072e 100644
--- a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
+++ b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
@@ -60,7 +60,7 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'
 
         x_input = x[0]
-        in_shape = x_input.shape
+        in_shape = x_input[0].shape if isinstance(x_input, (list,tuple)) else x_input.shape
 
         # high res shortcut
         shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/__init__.py b/modules/pytorch_jacinto_ai/vision/transforms/__init__.py
index fe96ee7..1ef9803 100644
--- a/modules/pytorch_jacinto_ai/vision/transforms/__init__.py
+++ b/modules/pytorch_jacinto_ai/vision/transforms/__init__.py
@@ -1,2 +1,3 @@
 from .transforms import *
 from . import image_transforms
+from . import image_transforms_xv12
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py b/modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py
index 65e5012..c4a60c6 100644
--- a/modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py
+++ b/modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py
@@ -113,15 +113,27 @@ class ImageTransformUtils(object):
         return img
 
     @staticmethod
-    def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False):
-        img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+    def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
+        if resize_in_yv12:
+            yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+            yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+            img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+        else:
+            img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        #
         img = ImageTransformUtils.crop(img, r, c, h, w)
         return img
 
     @staticmethod
-    def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
         img = ImageTransformUtils.crop(img, r, c, h, w)
-        img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        if resize_in_yv12:
+            yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+            yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+            img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+        else:
+            img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        #
         return img
 
     @staticmethod
@@ -161,3 +173,92 @@ class ImageTransformUtils(object):
         #
         return inputs
+
+    #############################################################
+    # functions for nv12
+
+    @staticmethod
+    def resize_img_yv12(img, size, interpolation=-1, is_flow=False):
+        #if (len(img.shape) == 3) and (img.shape[2] == 1 or img.shape[2] == 3):
+        #    return __class__.resize_fast(img, size, interpolation)
+        debug_print = False
+        in_w = img.shape[1]
+        in_h = (img.shape[0] * 2) // 3
+        y_h = in_h
+        uv_h = in_h // 4
+        u_w = in_w // 2
+
+        Y = img[0:in_h, 0:in_w]
+        V = img[y_h:y_h + uv_h, 0:in_w]
+        #print(V[0:2,0:8])
+        #print(V[0:2, u_w:u_w+8])
+        V = V.reshape(V.shape[0]*2, -1)
+        #print(V[0:2, 0:8])
+        #print(V[0:2, u_w:u_w + 8])
+        U = img[y_h + uv_h:y_h + 2 * uv_h, 0:in_w]
+        U = U.reshape(U.shape[0] * 2, -1)
+
+        out_h, out_w = size
+        if interpolation < 0:
+            interpolation = cv2.INTER_AREA if ((out_h < in_h) or (out_w < in_w)) else cv2.INTER_LINEAR
+
+        Y = cv2.resize(Y, (out_w, out_h), interpolation=interpolation)
+        U = cv2.resize(U, (out_w//2, out_h//2), interpolation=interpolation)
+        V = cv2.resize(V, (out_w//2, out_h//2), interpolation=interpolation)
+
+        img = np.zeros((out_h*3//2, out_w), dtype='uint8')
+        op_uv_h = out_h // 4
+
+        img[0:out_h, 0:out_w] = Y[:, :]
+        #print(V[0:2,0:8])
+        V = V.reshape(V.shape[0] // 2, -1)
+        #print(V[0:1,0:8])
+        #print(V[0:1, op_u_w:op_u_w+8])
+        img[out_h:out_h + op_uv_h, 0:out_w] = V
+        U = U.reshape(U.shape[0] // 2, -1)
+        img[out_h + op_uv_h:out_h + 2 * op_uv_h, 0:out_w] = U
+
+        if debug_print:
+            h = img.shape[0] * 2 // 3
+            w = img.shape[1]
+            print("-" * 32, "Resize in YV12")
+            print("Y")
+            print(img[0:5, 0:5])
+
+            print("V Odd Lines")
+            print(img[h:h + 5, 0:5])
+
+            print("V Even Lines")
+            print(img[h:h + 5, w // 2:w // 2 + 5])
+
+            print("U Odd Lines")
+            print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
+            print("U Even Lines")
+            print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+
+            print("-" * 32)
+
+        if is_flow:
+            ratio_h = out_h / in_h
+            ratio_w = out_w / in_w
+            img = ImageTransformUtils.scale_flow(img, ratio_w, ratio_h)
+
+        return img
+
+
+    # @staticmethod
+    # def resize_and_crop_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+    #     yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+    #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+    #     img = ImageTransformUtils.crop(img, r, c, h, w)
+    #     return img
+    #
+    # @staticmethod
+    # def crop_and_resize_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    #     img = ImageTransformUtils.crop(img, r, c, h, w)
+    #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+    #     yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+    #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+    #     return img
\ No newline at end of file
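A standalone sketch (not part of the patch) of the YV12 buffer layout that
resize_img_yv12 relies on: for an HxW frame the buffer is (H*3/2)xW uint8 — a
full-size Y plane followed by quarter-size V and U planes stored as H/4 rows of
width W, which reshape to H/2 x W/2 half-resolution chroma planes.

    import numpy as np

    h, w = 8, 8                        # frame size (multiples of 4 for YV12)
    yv12 = np.arange(h * 3 // 2 * w, dtype=np.uint8).reshape(h * 3 // 2, w)

    Y = yv12[0:h, :]                   # (8, 8) luma plane
    V = yv12[h:h + h // 4, :]          # (2, 8) quarter-size V plane
    U = yv12[h + h // 4:, :]           # (2, 8) quarter-size U plane

    V2 = V.reshape(h // 2, w // 2)     # (4, 4), as in resize_img_yv12
    U2 = U.reshape(h // 2, w // 2)     # (4, 4)
    print(Y.shape, V2.shape, U2.shape)  # (8, 8) (4, 4) (4, 4)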
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py
index 9b46e5c..fbf9ac5 100644
--- a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py
+++ b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py
@@ -48,9 +48,10 @@ class AlignImages(object):
 class ConvertToTensor(object):
     """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
     def __call__(self, images, targets):
-        def func(img, img_idx):
-            img = ImageTransformUtils.array_to_tensor(img)
-            return img
+        def func(imgs, img_idx):
+            imgs = [ImageTransformUtils.array_to_tensor(img_plane) for img_plane in imgs] \
+                if isinstance(imgs, list) else ImageTransformUtils.array_to_tensor(imgs)
+            return imgs
 
         images = ImageTransformUtils.apply_to_list(func, images)
         targets = ImageTransformUtils.apply_to_list(func, targets)
@@ -314,20 +315,26 @@ class RandomColor2Gray(object):
 
 class RandomScaleCrop(object):
     """Randomly zooms images up to 15% and crop them to keep same size as before."""
-    def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False):
+    def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False, resize_in_yv12=False):
         self.img_resize = img_resize
         self.scale_range = scale_range
         self.is_flow = is_flow
         self.center_crop = center_crop
+        self.resize_in_yv12 = resize_in_yv12
 
     @staticmethod
-    def get_params(img, img_resize, scale_range, center_crop):
+    def get_params(img, img_resize, scale_range, center_crop, resize_in_yv12=False):
         in_h, in_w = img.shape[:2]
         out_h, out_w = img_resize
+        if resize_in_yv12:
+            # align to a multiple of 4 so that U,V can be properly represented in YV12 format
+            round_or_align4 = lambda x: ((int(x)//4)*4)
+        else:
+            round_or_align4 = lambda x: round(x)
 
         # this random scaling is w.r.t. the output size
         if (np.random.random() < 0.5):
-            resize_h = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
-            resize_w = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
+            resize_h = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
+            resize_w = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
         else:
             resize_h, resize_w = out_h, out_w
 
@@ -337,11 +344,12 @@ class RandomScaleCrop(object):
         return out_r, out_c, out_h, out_w, resize_h, resize_w
 
     def __call__(self, images, targets):
-        out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop)
+        out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop,
+                                                                         resize_in_yv12=self.resize_in_yv12)
 
         def func_img(img, img_idx):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
-            img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img)
+            img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
             return img
 
         def func_tgt(img, img_idx):
@@ -356,11 +364,12 @@ class RandomScaleCrop(object):
 
 class RandomCropScale(object):
     """Crop the Image to random size and scale to the given resolution"""
-    def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False):
+    def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False, resize_in_yv12=False):
         self.size = size if (type(size) in (list,tuple)) else (size, size)
         self.crop_range = crop_range
         self.is_flow = is_flow
         self.center_crop = center_crop
+        self.resize_in_yv12 = resize_in_yv12
 
     @staticmethod
     def get_params(img, crop_range, center_crop):
@@ -386,7 +395,7 @@ class RandomCropScale(object):
 
         def func_img(img, img_idx):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
-            img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img)
+            img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
             return img
 
         def func_tgt(img, img_idx):
@@ -514,10 +523,16 @@ class NormalizeMeanScale(object):
         self.mean = mean
         self.scale = scale
 
-    def __call__(self, images, target):
-        if isinstance(images, (list,tuple)):
-            images = [(img-self.mean)*self.scale for img in images]
-        else:
-            images = (images-self.mean)*self.scale
+    def __call__(self, images, target):
+        def func(imgs, img_idx):
+            if isinstance(imgs, (list,tuple)):
+                imgs = [(img-self.mean)*self.scale for img in imgs]
+            else:
+                imgs = (imgs-self.mean)*self.scale
+            #
+            return imgs
+        #
+        images = ImageTransformUtils.apply_to_list(func, images) \
+            if isinstance(images, (list,tuple)) else func(images)
 
         return images, target
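Illustration only: how the updated transforms handle an entry that is itself a
list of planes (e.g. the [Y, UV] pair produced by RGBtoNV12); the values and
the standalone normalize helper below are made up.

    import numpy as np

    def normalize(imgs, mean=128.0, scale=1/64.0):
        # apply the op per plane when an image is a [Y, UV]-style list
        if isinstance(imgs, (list, tuple)):
            return [(img - mean) * scale for img in imgs]
        return (imgs - mean) * scale

    rgb = np.full((4, 4, 3), 200, dtype=np.float32)
    y   = np.full((4, 4, 1), 150, dtype=np.float32)
    uv  = np.full((2, 4, 1), 100, dtype=np.float32)

    out = [normalize(img) for img in [rgb, [y, uv]]]          # per-image, per-plane
    print(out[0].mean(), out[1][0].mean(), out[1][1].mean())  # 1.125 0.34375 -0.4375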
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
new file mode 100644
index 0000000..0c50cc2
--- /dev/null
+++ b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
@@ -0,0 +1,253 @@
+from __future__ import division
+import random
+import numbers
+import math
+import cv2
+import numpy as np
+import PIL
+
+from .image_transform_utils import *
+from .image_transforms import *
+
+
+def yv12_to_nv12_image(yv12=None):
+    image_w = yv12.shape[1]
+    image_h = (yv12.shape[0] * 2) // 3
+    y_h = image_h
+    uv_h = image_h // 4
+    u_w = image_w // 2
+
+    Y = yv12[0:image_h, 0:image_w]
+    V = yv12[y_h:y_h + uv_h, 0:image_w]
+    U = yv12[y_h + uv_h:y_h + 2 * uv_h, 0:image_w]
+
+    UV = np.zeros((Y.shape[0] // 2, Y.shape[1]), dtype='uint8')
+
+    # U00V00 U01V01 ....
+    # U10V10 U11V11 ....
+    # U20V20 U21V21 ....
+
+    UV[0::2, 0::2] = U[:, 0:u_w]
+    UV[0::2, 1::2] = V[:, 0:u_w]
+
+    UV[1::2, 0::2] = U[:, u_w:]
+    UV[1::2, 1::2] = V[:, u_w:]
+    Y = np.expand_dims(Y, axis=2)
+    UV = np.expand_dims(UV, axis=2)
+
+    img = [Y, UV]
+
+    test = False
+    if test:
+        op_yuv = np.zeros((yv12.shape))
+        U[:, 0:u_w] = UV[0::2, 0::2, 0]
+        V[:, 0:u_w] = UV[0::2, 1::2, 0]
+        U[:, u_w:] = UV[1::2, 0::2, 0]
+        V[:, u_w:] = UV[1::2, 1::2, 0]
+        op_yuv[0:image_h, 0:image_w] = Y[:, :, 0]
+        op_yuv[y_h:y_h + uv_h, 0:image_w] = V
+        op_yuv[y_h + uv_h:y_h + 2 * uv_h, 0:image_w] = U
+        assert (np.array_equal(yv12, op_yuv))
+
+    return img
+
+def nv12_to_bgr_image(Y=None, UV=None, image_scale=None, image_mean=None):
+    image_h = Y.shape[1]
+    image_w = Y.shape[2]
+
+    y_h = image_h
+    uv_h = image_h // 4
+    u_w = image_w // 2
+
+    op_yuv = torch.zeros(((image_h*3)//2, image_w), device=Y.device)
+    U = torch.zeros((uv_h, image_w), device=Y.device)
+    V = torch.zeros((uv_h, image_w), device=Y.device)
+    U[:, 0:u_w] = UV[0, 0::2, 0::2]
+    V[:, 0:u_w] = UV[0, 0::2, 1::2]
+    U[:, u_w:] = UV[0, 1::2, 0::2]
+    V[:, u_w:] = UV[0, 1::2, 1::2]
+    op_yuv[0:image_h, 0:image_w] = Y[0, :, :]
+    op_yuv[y_h:y_h + uv_h, 0:image_w] = V
+    op_yuv[y_h + uv_h:y_h + 2 * uv_h, 0:image_w] = U
+
+    op_yuv = op_yuv / torch.tensor(image_scale, device=Y.device) + torch.tensor(image_mean, device=Y.device)
+    op_yuv = op_yuv.cpu().numpy().astype('uint8')
+    bgr = cv2.cvtColor(op_yuv, cv2.COLOR_YUV2BGR_YV12)
+    return bgr
+
+
+class RGBtoYV12(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12)
+                h = images[img_idx].shape[0] * 2 // 3
+                w = images[img_idx].shape[1]
+                debug_print = False
+                if debug_print:
+                    print("-" * 32, "RGBtoYV12")
+                    print("Y")
+                    img = images[img_idx]
+                    print(img[0:5, 0:5])
+
+                    print("V Odd Lines")
+                    print(img[h:h + 5, 0:5])
+
+                    print("V Even Lines")
+                    print(img[h:h + 5, w // 2:w // 2 + 5])
+
+                    print("U Odd Lines")
+                    print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
+                    print("U Even Lines")
+                    print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+
+                    print("-" * 32)
+        return images, target
Lines") + print(img[h + h // 4:h + h // 4 + 5, 0:5]) + + print("U Even Lines") + print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5]) + + print("-" * 32) + return images, target + + +class YV12toRGB(object): + def __init__(self, is_flow=None, keep_rgb=False): + self.is_flow = is_flow + self.keep_rgb = keep_rgb + + def __call__(self, images, target): + for img_idx in range(len(images)): + is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow) + if not is_flow_img: + images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12) + debug_print = False + if debug_print: + print("shape after YV12toRGB ", images[img_idx].shape) + print("-" * 32, "YV12toRGB") + img = images[img_idx] + print(img[0:5, 0:5, 0]) + print(img[0:5, 0:5, 1]) + print(img[0:5, 0:5, 2]) + print("-" * 32) + return images, target + + +class YV12toNV12(object): + def __init__(self, is_flow=None, keep_rgb=False): + self.is_flow = is_flow + self.keep_rgb = keep_rgb + + def __call__(self, images, target): + for img_idx in range(len(images)): + [Y, UV] = yv12_to_nv12_image(yv12=images[img_idx]) + rgb = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12) + images[img_idx] = [Y, UV, rgb] if self.keep_rgb else [Y, UV] + + return images, target + + +class RGBtoNV12(object): + def __init__(self, is_flow=None, keep_rgb=False): + self.is_flow = is_flow + self.keep_rgb = keep_rgb + + def __call__(self, images, target): + + for img_idx in range(len(images)): + is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow) + if not is_flow_img: + # current + yv12 = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12) + + # cfg:20 + # yv12 = cv2.cvtColor(images[img_idx], cv2.COLOR_BGR2YUV_YV12) + [Y, UV] = yv12_to_nv12_image(yv12=yv12) + images[img_idx] = [Y, UV, images[img_idx]] if self.keep_rgb else [Y, UV] + return images, target + + +class RGBtoNV12toRGB(object): + def __init__(self, is_flow=None): + self.is_flow = is_flow + + def __call__(self, images, target): + + for img_idx in range(len(images)): + is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow) + if not is_flow_img: + # OpenCV does not support cv2.COLOR_RGB2YUV_NV12 so instead use YV12 as + # intermediate format. Final effect should be same. 
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/activation.py b/modules/pytorch_jacinto_ai/xnn/layers/activation.py
index 21c94c0..325764e 100644
--- a/modules/pytorch_jacinto_ai/xnn/layers/activation.py
+++ b/modules/pytorch_jacinto_ai/xnn/layers/activation.py
@@ -130,28 +130,9 @@ def get_fixed_pact2(inplace=False, signed=None, output_range=None):
 
 
 ###############################################################
-class ReLUN(torch.nn.Module):
-    def __init__(self, inplace=False, signed=False, clips=None, **kwargs):
-        super().__init__()
-        self.clips_act = clips
-        self.inplace = inplace
-        self.signed = signed
-
-    def forward(self, x):
-        y = torch.clamp(x, 0.0, self.clips_act)
-        return y
-
-    def get_clips_act(self):
-        return 0.0, self.clips_act
-
-    def __repr__(self):
-        return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, self.clips_act)
-
-
-###############################################################
-class ReLU8(ReLUN):
-    def __init__(self, inplace=False, signed=False, **kwargs):
-        super().__init__(inplace, signed, (0,8.0))
+class ReLU1(torch.nn.Hardtanh):
+    def __init__(self, min_val=0., max_val=1., inplace=False):
+        super().__init__(min_val=min_val, max_val=max_val, inplace=inplace)
 
 
 ###############################################################
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
index 8f9c657..0e9b4b4 100644
--- a/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
+++ b/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
@@ -20,6 +20,24 @@ class AddBlock(torch.nn.Module):
     def __repr__(self):
         return 'AddBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
 
+
+# sub
+class SubtractBlock(torch.nn.Module):
+    def __init__(self, inplace=False, signed=True, *args, **kwargs):
+        super().__init__()
+        self.inplace = inplace
+        self.signed = signed
+
+    def forward(self, x):
+        assert isinstance(x, (list,tuple)), 'input to sub block must be a list or tuple'
+        y = x[0]
+        for i in range(1,len(x)):
+            y = y - x[i]
+        #
+        return y
+
+    def __repr__(self):
+        return 'SubtractBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
+
 
 ###########################################################
 # mult
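Illustration only: SubtractBlock folds a list of tensors left to right by
subtraction, mirroring AddBlock. A standalone equivalent of its forward:

    import torch

    def subtract_list(x):
        assert isinstance(x, (list, tuple)), 'input must be a list or tuple'
        y = x[0]
        for i in range(1, len(x)):
            y = y - x[i]
        return y

    a, b, c = torch.tensor([10.]), torch.tensor([3.]), torch.tensor([2.])
    print(subtract_list([a, b, c]))  # tensor([5.])  == a - b - c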
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/layer_config.py b/modules/pytorch_jacinto_ai/xnn/layers/layer_config.py
index e1ced5c..1108af0 100644
--- a/modules/pytorch_jacinto_ai/xnn/layers/layer_config.py
+++ b/modules/pytorch_jacinto_ai/xnn/layers/layer_config.py
@@ -24,7 +24,7 @@ DefaultNorm2d = torch.nn.BatchNorm2d #SlowBatchNorm2d #Group8Norm
 # Default Activation
 # DefaultAct2d can be set to one of the activation types
 ###############################################################
-DefaultAct2d = torch.nn.ReLU #torch.nn.ReLU6 #torch.nn.HardTanh ##pytorch_jacinto_ai.xnn.layers.ReLU8
+DefaultAct2d = torch.nn.ReLU #torch.nn.Hardtanh
 
 ###############################################################
 # Default Convolution: torch.nn.Conv2d or ConvWS2d
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
index 31cbbc3..ba90775 100644
--- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
+++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
@@ -139,7 +139,7 @@ class QuantGraphModule(HookedModule):
                 pass
             elif qparams.quantize_out:
                 if utils.is_activation(module):
-                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
+                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6)):
                         activation_q = layers.PAct2(signed=False)
                     elif isinstance(module, torch.nn.Hardtanh):
                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
index b1a882d..d7fbf03 100644
--- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
+++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
@@ -98,7 +98,7 @@ class QuantTrainModule(QuantBaseModule):
             elif isinstance(m, layers.NoAct):
                 new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                         per_channel_q=self.per_channel_q)
-            elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
+            elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                 new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                         per_channel_q=self.per_channel_q)
             else:
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py b/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
index fc467be..eab5d3f 100644
--- a/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
+++ b/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
@@ -99,11 +99,11 @@ def check_model_data(model, data, verbose=False, ignore_names=('num_batches_trac
     not_matching_sizes = [k for k in model_dict.keys() if ((k in data.keys()) and (data[k].size() != model_dict[k].size()))]
 
     if missing_weights:
-        print_utils.print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", missing_weights)
+        print_utils.print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", *missing_weights, sep="\n")
     if not_matching_sizes:
-        print_utils.print_yellow("=> The shape of the following weights did not match: ", not_matching_sizes)
+        print_utils.print_yellow("=> The shape of the following weights did not match: ", *not_matching_sizes, sep="\n")
     if extra_weights:
-        print_utils.print_yellow("=> The following weights in pre-trained were not used: ", extra_weights)
+        print_utils.print_yellow("=> The following weights in pre-trained were not used: ", *extra_weights, sep="\n")
 
     return missing_weights, extra_weights, not_matching_sizes
 
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
index 20b0357..c12a26f 100644
--- a/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
+++ b/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
@@ -10,7 +10,7 @@ def is_normalization(module):
 
 def is_activation(module):
     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
-                                 layers.NoAct, layers.PAct2, layers.ReLUN))
+                                 layers.NoAct, layers.PAct2))
     return is_act
 
 def is_pact2(module):
@@ -80,14 +80,12 @@ def is_not_list(inp):
 
 def is_fixed_range(op):
     return isinstance(op, (torch.nn.ReLU6, torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.Hardtanh, \
-                           layers.PAct2, layers.ReLUN))
+                           layers.PAct2))
 
 def get_range(op):
     if isinstance(op, layers.PAct2):
         return op.get_clips_act()
-    elif isinstance(op, layers.ReLUN):
-        return op.get_clips_act()
     elif isinstance(op, torch.nn.ReLU6):
         return 0.0, 6.0
     elif isinstance(op, torch.nn.Sigmoid):