torch.nn.ReLU is the recommended activation module; removed the custom-defined module...
author Manu Mathew <a0393608@ti.com>
Thu, 12 Mar 2020 10:45:36 +0000 (16:15 +0530)
committer Manu Mathew <a0393608@ti.com>
Thu, 12 Mar 2020 10:52:15 +0000 (16:22 +0530)
16 files changed:
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/losses/loss_utils.py
modules/pytorch_jacinto_ai/vision/losses/norm_loss.py
modules/pytorch_jacinto_ai/vision/models/multi_input_net.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/vision/transforms/__init__.py
modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py
modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py
modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/activation.py
modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/layer_config.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py

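A minimal sketch of what this change amounts to for model code (the conv block below is hypothetical; ReLU1 is the Hardtanh subclass this commit adds in xnn/layers/activation.py). Clipped activations previously built with the removed ReLUN/ReLU8 now use torch.nn.ReLU, torch.nn.ReLU6 or a Hardtanh:

    import torch

    # hypothetical block: torch.nn.ReLU replaces the removed custom activation
    block = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(16),
        torch.nn.ReLU(inplace=True),   # was: layers.ReLU8()
    )

    # the new ReLU1 is equivalent to a Hardtanh clipped to [0, 1]
    act = torch.nn.Hardtanh(min_val=0., max_val=1.)
    y = act(block(torch.rand(1, 3, 32, 32)))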
index 605cef2d44ae3464aa7b8e98f549fcdb59211501..e9bf7322204a02ef089a7c0769d5f9d51e079ea1 100644 (file)
@@ -52,7 +52,7 @@ def get_config():
     args.model_config.freeze_decoder = False            # do not update decoder weights
     args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
     args.model_config.target_input_ratio = 1            # Keep target size same as input size
-
+    args.model_config.input_nv12 = False                # set True for NV12 input: each image is a [Y, UV] pair
     args.model = None                                   # the model itself can be given from outside
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
@@ -169,6 +169,7 @@ def get_config():
     args.tensorboard_enable = True                      # en/disable of TB writing
     args.print_train_class_iou = False
     args.print_val_class_iou = False
+    args.freeze_layers = None                           # list of layer name substrings; matching modules are frozen during training
 
     return args
 
@@ -310,19 +311,24 @@ def main(args):
     #################################################
     pretrained_data = None
     model_surgery_quantize = False
     if args.pretrained and args.pretrained != "None":
-        if isinstance(args.pretrained, dict):
-            pretrained_data = args.pretrained
-        else:
-            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
-                pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        pretrained_data = []
+        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
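+        # (hypothetical paths) e.g. args.pretrained = ['./data/encoder.pth', './data/decoder.pth'];
+        # each entry may be a URL, a local file path, or an already-loaded state dict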
+        for p in pretrained_files:
+            if isinstance(p, dict):
+                p_data = p
             else:
-                pretrained_file = args.pretrained
+                if p.startswith('http://') or p.startswith('https://'):
+                    p_file = vision.datasets.utils.download_url(p, './data/downloads')
+                else:
+                    p_file = p
+                #
+                print(f'=> loading pretrained weights file: {p}')
+                p_data = torch.load(p_file)
             #
-            print(f'=> using pre-trained weights from: {args.pretrained}')
-            pretrained_data = torch.load(pretrained_file)
-        #
-        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+            pretrained_data.append(p_data)
+            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
     #
 
     #################################################
@@ -361,7 +367,9 @@ def main(args):
 
     # load pretrained model
     if pretrained_data is not None:
-        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+        for (p_data, p_file) in zip(pretrained_data, pretrained_files):
+            print(f'=> using pretrained weights from: {p_file}')
+            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
     #
 
     #################################################
@@ -594,6 +602,18 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
     if args.freeze_bn:
         xnn.utils.freeze_bn(model)
     #
+
+    # freeze layers
+    if args.freeze_layers is not None:
+        # 'freeze_layer_name' may be a substring of 'name'; 'name' need not exactly match 'freeze_layer_name'.
+        # e.g. freeze_layer_name = 'encoder.0' freezes all sub-modules such as 'encoder.0.0' and 'encoder.0.1'
+        for freeze_layer_name in args.freeze_layers:
+            for name, module in model.named_modules():
+                if freeze_layer_name in name:
+                    xnn.utils.print_once("Freezing the module : {}".format(name))
+                    module.eval()
+                    for param in module.parameters():
+                        param.requires_grad = False
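+    # usage sketch (hypothetical layer names): args.freeze_layers = ['encoder.0', 'decoder.aspp']
+    # would freeze every module whose qualified name contains one of those substrings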
 
     ##########################
     for task_dx, task_losses in enumerate(args.loss_modules):
@@ -621,7 +641,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
 
         lr = scheduler.get_lr()[0]
 
-        input_list = [img.cuda() for img in inputs]
+        # with NV12 input, an element of 'inputs' is itself a [Y, UV] list; move each plane to the GPU
+        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
         target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
         target_sizes = [tgt.shape for tgt in target_list]
         batch_size_cur = target_sizes[0][0]
@@ -632,7 +652,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
 
         task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
         # upsample output to target resolution
-        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+        if args.upsample_mode is not None:
+            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
 
         if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
             args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
@@ -774,7 +795,7 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
     ##########################
     for iter, (inputs, targets) in enumerate(val_loader):
         data_time.update(time.time() - end_time)
-        input_list = [j.cuda() for j in inputs]
+        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
         target_list = [j.cuda(non_blocking=True) for j in targets]
         target_sizes = [tgt.shape for tgt in target_list]
         batch_size_cur = target_sizes[0][0]
@@ -782,8 +803,10 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
         # compute output
         task_outputs = model(input_list)
 
-        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
-        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+        task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
+        if args.upsample_mode is not None:
+            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
         
         if args.print_val_class_iou:
             metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
@@ -864,13 +887,20 @@ def get_model_orig(model):
 
 def create_rand_inputs(args, is_cuda):
     dummy_input = []
-    for i_ch in args.model_config.input_channels:
-        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
-        x = x.cuda() if is_cuda else x
-        dummy_input.append(x)
-    #
-    return dummy_input
+    if not args.model_config.input_nv12:
+        for i_ch in args.model_config.input_channels:
+            x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
+            x = x.cuda() if is_cuda else x
+            dummy_input.append(x)
+    else: # nv12
+        for i_ch in args.model_config.input_channels:
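+            # NV12 is modelled as two planes per input: a full-resolution Y plane (H x W) and an
+            # interleaved UV plane of half height and full width; i_ch is unused since both planes are single-channel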
+            y = torch.rand((1, 1, args.img_resize[0], args.img_resize[1]))
+            uv = torch.rand((1, 1, args.img_resize[0]//2, args.img_resize[1]))
+            y = y.cuda() if is_cuda else y
+            uv = uv.cuda() if is_cuda else uv
+            dummy_input.append([y,uv])
 
+    return dummy_input
 
 def count_flops(args, model):
     is_cuda = next(model.parameters()).is_cuda
@@ -922,15 +952,23 @@ def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writ
     write_prob = np.random.random()
     if (write_prob > write_freq):
         return
-
-    batch_size = input_images[0].shape[0]
+    if args.model_config.input_nv12:
+        batch_size = input_images[0][0].shape[0]
+    else:
+        batch_size = input_images[0].shape[0]
     b_index = random.randint(0, batch_size - 1)
 
     input_image = None
     for img_idx, img in enumerate(input_images):
-        input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
-        # convert back to original input range (0-255)
-        input_image = input_image / args.image_scale + args.image_mean
+        if args.model_config.input_nv12:
+            # convert NV12 to BGR for tensorboard
+            input_image = vision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y=input_images[img_idx][0][b_index], UV=input_images[img_idx][1][b_index],
+                                                                                    image_scale=args.image_scale, image_mean=args.image_mean)
+        else:
+            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
+            # convert back to original input range (0-255)
+            input_image = input_image / args.image_scale + args.image_mean
+
         if args.is_flow and args.is_flow[0][img_idx]:
             #input corresponding to flow is assumed to have been generated by adding 128
             flow = input_image - 128
index 21a04bc33a070624ea4df11f5e6e7c7d0787b183..2df94858e8ffa9334794525d6159207aeb1d4e1a 100644 (file)
@@ -10,6 +10,8 @@ def l2_norm(x,y=None):
     diff = (x-y) if y is not None else x
     return torch.norm(diff,p=2,dim=1,keepdim=True)
 
+def l2_norm_self(x,y=None):
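+    # y is accepted only for API compatibility with l2_norm; it is ignored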
+    return torch.norm(x,p=2,dim=1,keepdim=True)
 
 def l1_norm(x,y=None):
     if y is not None:
index 10fb06ee42c0c4379d60350f3e61580de019a4e8..0c317ddd3364abc0e722639d5819e13c7379340b 100644 (file)
@@ -15,6 +15,13 @@ class L2NormDiff(BasicNormLossModule):
 supervised_l2_loss = l2_norm_loss = L2NormDiff
 
 
+# take the norm of the tensor itself as the loss, instead of the norm of a difference
+class L2NormSelf(BasicNormLossModule):
+    def __init__(self, sparse=False, error_fn=l2_norm_self, error_name='L2NormSelf'):
+        super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)
+supervised_l2_loss_self = l2_norm_loss_self = L2NormSelf
+
+
 ########################################
 class SmoothL1Diff(BasicElementwiseLossModule):
     def __init__(self, sparse=False, error_fn=smooth_l1_loss, error_name='SmoothL1Diff'):
index 4e760586f6ebdc6fa8473b6490940d400d58e416..b411c0d2509132beddd468222fd7afad7746dbd7 100644 (file)
@@ -7,11 +7,11 @@ from .resnet import resnet50_with_model_config
 try: from .mobilenetv2_ericsun_internal import *
 except: pass
 
-try: from .mobilenetv2_gws_internal import *
+try: from .mobilenetv2_internal import *
 except: pass
 
 __all__ = ['MultiInputNet', 'mobilenet_v2_tv_mi4', 'mobilenet_v2_tv_gws_mi4', 'mobilenet_v2_ericsun_mi4',
-           'MobileNetV2TVMI4', 'ResNet50MI4']
+           'mobilenet_v2_tv_nv12_mi4', 'MobileNetV2TVMI4', 'MobileNetV2TVNV12MI4', 'ResNet50MI4']
 
 
 ###################################################
@@ -212,6 +212,16 @@ class MobileNetV2EricsunMI4(MultiInputNet):
 mobilenet_v2_ericsun_mi4 = MobileNetV2EricsunMI4
 
 
+# multi input block variant that takes NV12 (Y + interleaved UV) inputs
+class MobileNetV2TVNV12MI4(MultiInputNet):
+    def __init__(self, model_config):
+        model_config.num_input_blocks = 4
+        model_config.fuse_channels = 24
+        super().__init__(MobileNetV2TVNV12, model_config)
+#
+mobilenet_v2_tv_nv12_mi4 = MobileNetV2TVNV12MI4
+
+# these are the real multi input blocks
 class MobileNetV2TVGWSMI4(MultiInputNet):
     def __init__(self, model_config):
         model_config.num_input_blocks = 4
index a7da59044b8ae665d69f9c2fc1d0347c7b896eba..dda072e2383353af5455e9d90325df435e5cc4fa 100644 (file)
@@ -60,7 +60,7 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'
 
         x_input = x[0]
-        in_shape = x_input.shape
+        in_shape = x_input[0].shape if isinstance(x_input, (list,tuple)) else x_input.shape
 
         # high res shortcut
         shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
index fe96ee7e0d13ae9d5d8cd6a15a515cb21735f4de..1ef9803460f5db94fd218d4fdb70f64c2d897970 100644 (file)
@@ -1,2 +1,3 @@
 from .transforms import *
 from . import image_transforms
+from . import image_transforms_xv12
index 65e5012330cfd5c9d448fd8d24a057aa0e3505f1..c4a60c6a8a6748352b3256bbd3443bffafafae47 100644 (file)
@@ -113,15 +113,27 @@ class ImageTransformUtils(object):
         return img
 
     @staticmethod
-    def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False):
-        img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+    def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
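+        # when resize_in_yv12 is set, the resize is done on the YV12 planes (convert,
+        # resize per-plane, convert back) instead of directly on the RGB image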
+        if resize_in_yv12:
+            yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+            yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+            img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+        else:
+            img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        #
         img = ImageTransformUtils.crop(img, r, c, h, w)
         return img
 
     @staticmethod
-    def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False, resize_in_yv12=False):
         img = ImageTransformUtils.crop(img, r, c, h, w)
-        img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        if resize_in_yv12:
+            yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+            yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+            img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+        else:
+            img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
+        #
         return img
 
     @staticmethod
@@ -161,3 +173,92 @@ class ImageTransformUtils(object):
         #
         return inputs
 
+
+    #############################################################
+    # functions for nv12
+
+    @staticmethod
+    def resize_img_yv12(img, size, interpolation=-1, is_flow=False):
+        #if (len(img.shape) == 3) and (img.shape[2] == 1 or img.shape[2] == 3):
+        #    return __class__.resize_fast(img, size, interpolation)
+        debug_print = False
+        in_w = img.shape[1]
+        in_h = (img.shape[0] * 2) // 3
+        y_h = in_h
+        uv_h = in_h // 4
+        u_w = in_w // 2
+
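+        # YV12 layout assumed: an H x W luma plane followed by the V plane and then the U plane,
+        # each stored as (H//4) rows of W bytes, i.e. two half-width chroma rows packed per line;
+        # the reshapes below unpack each chroma plane to (H//2) x (W//2) before resizing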
+        Y = img[0:in_h, 0:in_w]
+        V = img[y_h:y_h + uv_h, 0:in_w]
+        #print(V[0:2,0:8])
+        #print(V[0:2, u_w:u_w+8])
+        V = V.reshape(V.shape[0]*2, -1)
+        #print(V[0:2, 0:8])
+        #print(V[0:2, u_w:u_w + 8])
+        U = img[y_h + uv_h:y_h + 2 * uv_h, 0:in_w]
+        U = U.reshape(U.shape[0] * 2, -1)
+
+        out_h, out_w = size
+        if interpolation < 0:
+            interpolation = cv2.INTER_AREA if ((out_h < in_h) or (out_w < in_w)) else cv2.INTER_LINEAR
+
+        Y = cv2.resize(Y, (out_w, out_h), interpolation=interpolation)
+        U = cv2.resize(U, (out_w//2, out_h//2), interpolation=interpolation)
+        V = cv2.resize(V, (out_w//2, out_h//2), interpolation=interpolation)
+
+        img = np.zeros((out_h*3//2, out_w), dtype='uint8')
+        op_uv_h = out_h // 4
+
+        img[0:out_h, 0:out_w] = Y[:, :]
+        #print(V[0:2,0:8])
+        V = V.reshape(V.shape[0] // 2, -1)
+        #print(V[0:1,0:8])
+        #print(V[0:1, op_u_w:op_u_w+8])
+        img[out_h:out_h + op_uv_h, 0:out_w] = V
+        U = U.reshape(U.shape[0] // 2, -1)
+        img[out_h + op_uv_h:out_h + 2 * op_uv_h, 0:out_w] = U
+
+        if debug_print:
+            h = img.shape[0] * 2 // 3
+            w = img.shape[1]
+            print("-" * 32, "Resize in YV12")
+            print("Y")
+            print(img[0:5, 0:5])
+
+            print("V Odd Lines")
+            print(img[h:h + 5, 0:5])
+
+            print("V Even Lines")
+            print(img[h:h + 5, w // 2:w // 2 + 5])
+
+            print("U Odd Lines")
+            print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
+            print("U Even Lines")
+            print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+
+            print("-" * 32)
+
+        if is_flow:
+            ratio_h = out_h / in_h
+            ratio_w = out_w / in_w
+            img = ImageTransformUtils.scale_flow(img, ratio_w, ratio_h)
+
+        return img
+
+
+    # @staticmethod
+    # def resize_and_crop_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+    #     yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+    #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+    #     img = ImageTransformUtils.crop(img, r, c, h, w)
+    #     return img
+    #
+    # @staticmethod
+    # def crop_and_resize_yv12(img, r, c, h, w, size, interpolation=-1, is_flow=False):
+    #     img = ImageTransformUtils.crop(img, r, c, h, w)
+    #     yv12 = cv2.cvtColor(img, cv2.COLOR_RGB2YUV_YV12)
+    #     yv12 = ImageTransformUtils.resize_img_yv12(yv12, size, interpolation, is_flow)
+    #     img = cv2.cvtColor(yv12, cv2.COLOR_YUV2RGB_YV12)
+    #     return img
\ No newline at end of file
index 9b46e5c8b89549688358027b4dfc1b300208a032..fbf9ac53908070fa0b25f62c25e2a645ab8105a5 100644 (file)
@@ -48,9 +48,10 @@ class AlignImages(object):
 class ConvertToTensor(object):
     """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
     def __call__(self, images, targets):
-        def func(img, img_idx):
-            img = ImageTransformUtils.array_to_tensor(img)
-            return img
+        def func(imgs, img_idx):
+            imgs = [ImageTransformUtils.array_to_tensor(img_plane) for img_plane in imgs] \
+                if isinstance(imgs, list) else ImageTransformUtils.array_to_tensor(imgs)
+            return imgs
 
         images = ImageTransformUtils.apply_to_list(func, images)
         targets = ImageTransformUtils.apply_to_list(func, targets)
@@ -314,20 +315,26 @@ class RandomColor2Gray(object):
 
 class RandomScaleCrop(object):
     """Randomly zooms images up to 15% and crop them to keep same size as before."""
-    def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False):
+    def __init__(self, img_resize, scale_range=(1.0,2.0), is_flow=None, center_crop=False, resize_in_yv12=False):
         self.img_resize = img_resize
         self.scale_range = scale_range
         self.is_flow = is_flow
         self.center_crop = center_crop
+        self.resize_in_yv12 = resize_in_yv12
 
     @staticmethod
-    def get_params(img, img_resize, scale_range, center_crop):
+    def get_params(img, img_resize, scale_range, center_crop, resize_in_yv12=False):
         in_h, in_w = img.shape[:2]
         out_h, out_w = img_resize
+        if resize_in_yv12:
+            # align to a multiple of 4 so that the U,V planes can be represented exactly in YV12 format
+            round_or_align4 = lambda x: ((int(x)//4)*4)
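+            # e.g. round_or_align4(1035.7) == 1032, so the resized chroma planes stay 4-aligned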
+        else:
+            round_or_align4 = lambda x: round(x)
         # this random scaling is w.r.t. the output size
         if (np.random.random() < 0.5):
-            resize_h = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
-            resize_w = int(round(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
+            resize_h = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_h))
+            resize_w = int(round_or_align4(np.random.uniform(scale_range[0], scale_range[1]) * out_w))
         else:
             resize_h, resize_w = out_h, out_w
 
@@ -337,11 +344,12 @@ class RandomScaleCrop(object):
         return out_r, out_c, out_h, out_w, resize_h, resize_w
 
     def __call__(self, images, targets):
-        out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop)
+        out_r, out_c, out_h, out_w, resize_h, resize_w = self.get_params(images[0], self.img_resize, self.scale_range, self.center_crop,
+                                                                         resize_in_yv12=self.resize_in_yv12)
 
         def func_img(img, img_idx):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
-            img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img)
+            img = ImageTransformUtils.resize_and_crop(img, out_r, out_c, out_h, out_w, (resize_h, resize_w), is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
             return img
 
         def func_tgt(img, img_idx):
@@ -356,11 +364,12 @@ class RandomScaleCrop(object):
 
 class RandomCropScale(object):
     """Crop the Image to random size and scale to the given resolution"""
-    def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False):
+    def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False, resize_in_yv12=False):
         self.size = size if (type(size) in (list,tuple)) else (size, size)
         self.crop_range = crop_range
         self.is_flow = is_flow
         self.center_crop = center_crop
+        self.resize_in_yv12 = resize_in_yv12
 
     @staticmethod
     def get_params(img, crop_range, center_crop):
@@ -386,7 +395,7 @@ class RandomCropScale(object):
 
         def func_img(img, img_idx):
             is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
-            img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img)
+            img = ImageTransformUtils.crop_and_resize(img, out_r, out_c, out_h, out_w, self.size, is_flow=is_flow_img, resize_in_yv12=self.resize_in_yv12)
             return img
 
         def func_tgt(img, img_idx):
@@ -514,10 +523,16 @@ class NormalizeMeanScale(object):
         self.mean = mean
         self.scale = scale
 
-    def __call__(self, images, target):
-        if isinstance(images, (list,tuple)):
-            images = [(img-self.mean)*self.scale for img in images]
-        else:
-            images = (images-self.mean)*self.scale
 
+    def __call__(self, images, target):
+        def func(imgs, img_idx):
+            if isinstance(imgs, (list,tuple)):
+                imgs = [(img-self.mean)*self.scale for img in imgs]
+            else:
+                imgs = (imgs-self.mean)*self.scale
+            #
+            return imgs
+        #
+        images = ImageTransformUtils.apply_to_list(func, images) \
+            if isinstance(images, (list,tuple)) else func(images, 0)
         return images, target
diff --git a/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py b/modules/pytorch_jacinto_ai/vision/transforms/image_transforms_xv12.py
new file mode 100644 (file)
index 0000000..0c50cc2
--- /dev/null
@@ -0,0 +1,253 @@
+from __future__ import division
+import random
+import numbers
+import math
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from .image_transform_utils import *
+from .image_transforms import *
+
+
+def yv12_to_nv12_image(yv12=None):
+    image_w = yv12.shape[1]
+    image_h = (yv12.shape[0] * 2) // 3
+    y_h = image_h
+    uv_h = image_h // 4
+    u_w = image_w // 2
+
+    Y = yv12[0:image_h, 0:image_w]
+    V = yv12[y_h:y_h + uv_h, 0:image_w]
+    U = yv12[y_h + uv_h:y_h + 2 * uv_h, 0:image_w]
+
+    UV = np.zeros((Y.shape[0] // 2, Y.shape[1]), dtype='uint8')
+
+    # U00V00   U01V01 ....
+    # U10V10   U11V11 ....
+    # U20V20   U21V21 ....
+
+    UV[0::2, 0::2] = U[:, 0:u_w]
+    UV[0::2, 1::2] = V[:, 0:u_w]
+
+    UV[1::2, 0::2] = U[:, u_w:]
+    UV[1::2, 1::2] = V[:, u_w:]
+    Y = np.expand_dims(Y, axis=2)
+    UV = np.expand_dims(UV, axis=2)
+
+    img = [Y, UV]
+
+    test = False
+    if test:
+        op_yuv = np.zeros(yv12.shape, dtype=yv12.dtype)
+        U[:, 0:u_w] = UV[0::2, 0::2, 0]
+        V[:, 0:u_w] = UV[0::2, 1::2, 0]
+        U[:, u_w:] = UV[1::2, 0::2, 0]
+        V[:, u_w:] = UV[1::2, 1::2, 0]
+        op_yuv[0:image_h, 0:image_w] = Y[:, :, 0]
+        op_yuv[y_h:y_h + uv_h, 0:image_w] = V
+        op_yuv[y_h + uv_h:y_h + 2 * uv_h, 0:image_w] = U
+        assert (np.array_equal(yv12, op_yuv))
+
+    return img
+
+def nv12_to_bgr_image(Y=None, UV=None, image_scale=None, image_mean=None):
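+    # reassemble a YV12 buffer from the Y plane and the interleaved UV plane, undo the
+    # mean/scale normalization, then convert to BGR (used for tensorboard visualization)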
+    image_h = Y.shape[1]
+    image_w = Y.shape[2]
+
+    y_h = image_h
+    uv_h = image_h // 4
+    u_w = image_w // 2
+
+    op_yuv = torch.zeros(((image_h*3)//2, image_w), device=Y.device)
+    U = torch.zeros((uv_h, image_w), device=Y.device)
+    V = torch.zeros((uv_h, image_w), device=Y.device)
+    U[:, 0:u_w] = UV[0, 0::2, 0::2]
+    V[:, 0:u_w] = UV[0, 0::2, 1::2]
+    U[:, u_w:] = UV[0, 1::2, 0::2]
+    V[:, u_w:] = UV[0, 1::2, 1::2]
+    op_yuv[0:image_h, 0:image_w] = Y[0, :, :]
+    op_yuv[y_h:y_h + uv_h, 0:image_w] = V
+    op_yuv[y_h + uv_h:y_h + 2 * uv_h, 0:image_w] = U
+
+    op_yuv = op_yuv / torch.tensor(image_scale, device=Y.device) + torch.tensor(image_mean, device=Y.device)
+    op_yuv = op_yuv.cpu().numpy().astype('uint8')
+    bgr = cv2.cvtColor(op_yuv, cv2.COLOR_YUV2BGR_YV12)
+    return bgr
+
+
+class RGBtoYV12(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12)
+                h = images[img_idx].shape[0] * 2 // 3
+                w = images[img_idx].shape[1]
+                debug_print = False
+                if debug_print:
+                    print("-" * 32, "RGBtoYV12")
+                    print("Y")
+                    img = images[img_idx]
+                    print(img[0:5, 0:5])
+
+                    print("V Odd Lines")
+                    print(img[h:h + 5, 0:5])
+
+                    print("V Even Lines")
+                    print(img[h:h + 5, w // 2:w // 2 + 5])
+
+                    print("U Odd Lines")
+                    print(img[h + h // 4:h + h // 4 + 5, 0:5])
+
+                    print("U Even Lines")
+                    print(img[h + h // 4:h + h // 4 + 5, w // 2:w // 2 + 5])
+
+                    print("-" * 32)
+        return images, target
+
+
+class YV12toRGB(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                images[img_idx] = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12)
+                debug_print = False
+                if debug_print:
+                    print("shape after YV12toRGB ", images[img_idx].shape)
+                    print("-" * 32, "YV12toRGB")
+                    img = images[img_idx]
+                    print(img[0:5, 0:5, 0])
+                    print(img[0:5, 0:5, 1])
+                    print(img[0:5, 0:5, 2])
+                    print("-" * 32)
+        return images, target
+
+
+class YV12toNV12(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+        for img_idx in range(len(images)):
+            [Y, UV] = yv12_to_nv12_image(yv12=images[img_idx])
+            rgb = cv2.cvtColor(images[img_idx], cv2.COLOR_YUV2RGB_YV12)
+            images[img_idx] = [Y, UV, rgb] if self.keep_rgb else [Y, UV]
+
+        return images, target
+
+
+class RGBtoNV12(object):
+    def __init__(self, is_flow=None, keep_rgb=False):
+        self.is_flow = is_flow
+        self.keep_rgb = keep_rgb
+
+    def __call__(self, images, target):
+
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                # current
+                yv12 = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12)
+
+                # cfg:20
+                # yv12 = cv2.cvtColor(images[img_idx], cv2.COLOR_BGR2YUV_YV12)
+                [Y, UV] = yv12_to_nv12_image(yv12=yv12)
+                images[img_idx] = [Y, UV, images[img_idx]] if self.keep_rgb else [Y, UV]
+        return images, target
+
+
+class RGBtoNV12toRGB(object):
+    def __init__(self, is_flow=None):
+        self.is_flow = is_flow
+
+    def __call__(self, images, target):
+
+        for img_idx in range(len(images)):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            if not is_flow_img:
+                # OpenCV does not support cv2.COLOR_RGB2YUV_NV12, so YV12 is used as the
+                # intermediate format; the final effect should be the same.
+                yuv = cv2.cvtColor(images[img_idx], cv2.COLOR_RGB2YUV_YV12)
+                images[img_idx] = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_YV12)
+        return images, target
+
+
+class RandomScaleCropYV12(RandomScaleCrop):
+    """Randomly zooms images up to 15% and crop them to keep same size as before."""
+
+    def __init__(self, img_resize, scale_range=(1.0, 2.0), is_flow=None, center_crop=False):
+        super().__init__(img_resize, scale_range=scale_range, is_flow=is_flow, center_crop=center_crop, resize_in_yv12=True)
+
+
+class RandomCropScaleYV12(RandomCropScale):
+    """Crop the Image to random size and scale to the given resolution, resizing in YV12"""
+
+    def __init__(self, size, crop_range=(0.08, 1.0), is_flow=None, center_crop=False):
+        super().__init__(size, crop_range=crop_range, is_flow=is_flow, center_crop=center_crop, resize_in_yv12=True)
+
+
+class ScaleYV12(object):
+    def __init__(self, img_size, target_size=None, is_flow=None):
+        self.img_size = img_size
+        self.target_size = target_size if target_size else img_size
+        self.is_flow = is_flow
+
+    def __call__(self, images, targets):
+        if self.img_size is None:
+            return images, targets
+
+        def func_img(img, img_idx):
+            is_flow_img = (self.is_flow[0][img_idx] if self.is_flow else self.is_flow)
+            img = ImageTransformUtils.resize_img_yv12(img, self.img_size, interpolation=-1, is_flow=is_flow_img)
+            return img
+
+        def func_tgt(img, img_idx):
+            is_flow_tgt = (self.is_flow[1][img_idx] if self.is_flow else self.is_flow)
+            img = ImageTransformUtils.resize_img(img, self.target_size, interpolation=cv2.INTER_NEAREST,
+                                                 is_flow=is_flow_tgt)
+            return img
+
+        images = ImageTransformUtils.apply_to_list(func_img, images)
+        targets = ImageTransformUtils.apply_to_list(func_tgt, targets)
+
+        return images, targets
+
+
index 21c94c01668b3db7d19928e3f300e67cb04d1229..325764e0caaea4cc4fb3dbe87ba9283f709fb41a 100644 (file)
@@ -130,28 +130,9 @@ def get_fixed_pact2(inplace=False, signed=None, output_range=None):
 
 
 ###############################################################
-class ReLUN(torch.nn.Module):
-    def __init__(self, inplace=False, signed=False, clips=None, **kwargs):
-        super().__init__()
-        self.clips_act = clips
-        self.inplace = inplace
-        self.signed = signed
-
-    def forward(self, x):
-        y = torch.clamp(x, 0.0, self.clips_act)
-        return y
-
-    def get_clips_act(self):
-        return 0.0, self.clips_act
-
-    def __repr__(self):
-        return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, self.clips_act)
-
-
-###############################################################
-class ReLU8(ReLUN):
-    def __init__(self, inplace=False, signed=False, **kwargs):
-        super().__init__(inplace, signed, (0,8.0))
+class ReLU1(torch.nn.Hardtanh):
+    def __init__(self, min_val=0., max_val=1., inplace=False):
+        super().__init__(min_val=min_val, max_val=max_val, inplace=inplace)
 
 
 ###############################################################
index 8f9c6577943fb879091b8251a5648deb95147c95..0e9b4b40c5a6905c7d412cfaa90ed1ada1bcabab 100644 (file)
@@ -20,6 +20,24 @@ class AddBlock(torch.nn.Module):
     def __repr__(self):
         return 'AddBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
 
+# sub
+class SubtractBlock(torch.nn.Module):
+    def __init__(self, inplace=False, signed=True, *args, **kwargs):
+        super().__init__()
+        self.inplace = inplace
+        self.signed = signed
+
+    def forward(self, x):
+        assert isinstance(x, (list,tuple)), 'input to sub block must be a list or tuple'
+        y = x[0]
+        for i in range(1,len(x)):
+            y = y - x[i]
+        #
+        return y
+
+    def __repr__(self):
+        return 'SubtractBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
+
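+# usage sketch: y = SubtractBlock()([a, b, c]) computes a - b - c elementwise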
 
 ###########################################################
 # mult
index e1ced5c7df02f337bd30294834afd899eeccb9f8..1108af0f1cbeca308ac2cfdbc1495a6a76a1858b 100644 (file)
@@ -24,7 +24,7 @@ DefaultNorm2d = torch.nn.BatchNorm2d #SlowBatchNorm2d #Group8Norm
 # Default Activation
 # DefaultAct2d can be set to one of the activation types
 ###############################################################
-DefaultAct2d = torch.nn.ReLU #torch.nn.ReLU6 #torch.nn.HardTanh ##pytorch_jacinto_ai.xnn.layers.ReLU8
+DefaultAct2d = torch.nn.ReLU #torch.nn.Hardtanh
 
 ###############################################################
 # Default Convolution: torch.nn.Conv2d or ConvWS2d
index 31cbbc35452193f1b9551d051eacb1335779bd02..ba907754d42dc76c0e2ce80ce0b60f924d262f19 100644 (file)
@@ -139,7 +139,7 @@ class QuantGraphModule(HookedModule):
                 pass
             elif qparams.quantize_out:
                 if utils.is_activation(module):
-                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
+                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6)):
                         activation_q = layers.PAct2(signed=False)
                     elif isinstance(module, torch.nn.Hardtanh):
                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
index b1a882d07534ff19fcbc606038651f956a97bcbd..d7fbf0386efa99d27a16ffe7916440a6a6d493bf 100644 (file)
@@ -98,7 +98,7 @@ class QuantTrainModule(QuantBaseModule):
                 elif isinstance(m, layers.NoAct):
                     new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                              per_channel_q=self.per_channel_q)
-                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
+                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                     new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                              per_channel_q=self.per_channel_q)
                 else:
index fc467be7646b0db4df1bd87c49e1380aa0fc99fd..eab5d3f043bb1c2703b406323ed880be1a5c3102 100644 (file)
@@ -99,11 +99,11 @@ def check_model_data(model, data, verbose=False, ignore_names=('num_batches_trac
     not_matching_sizes = [k for k in model_dict.keys() if ((k in data.keys()) and (data[k].size() != model_dict[k].size()))]
 
     if missing_weights:
-        print_utils.print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", missing_weights)
+        print_utils.print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", *missing_weights, sep = "\n")
     if not_matching_sizes:
-        print_utils.print_yellow("=> The shape of the following weights did not match: ", not_matching_sizes)
+        print_utils.print_yellow("=> The shape of the following weights did not match: ", *not_matching_sizes, sep = "\n")
     if extra_weights:
-        print_utils.print_yellow("=> The following weights in pre-trained were not used: ", extra_weights)
+        print_utils.print_yellow("=> The following weights in pre-trained were not used: ", *extra_weights, sep = "\n")
                 
     return missing_weights, extra_weights, not_matching_sizes
 
index 20b03577bde09060c3966a57538aac05c2ad12cb..c12a26fdf73109f70d14f55452aa8f797a43ca16 100644 (file)
@@ -10,7 +10,7 @@ def is_normalization(module):
 
 def is_activation(module):
     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
-                                 layers.NoAct, layers.PAct2, layers.ReLUN))
+                                 layers.NoAct, layers.PAct2))
     return is_act
 
 def is_pact2(module):
@@ -80,14 +80,12 @@ def is_not_list(inp):
 
 def is_fixed_range(op):
     return isinstance(op, (torch.nn.ReLU6, torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.Hardtanh, \
-                           layers.PAct2, layers.ReLUN))
+                           layers.PAct2))
 
 
 def get_range(op):
     if isinstance(op, layers.PAct2):
         return op.get_clips_act()
-    elif isinstance(op, layers.ReLUN):
-        return op.get_clips_act()
     elif isinstance(op, torch.nn.ReLU6):
         return 0.0, 6.0
     elif isinstance(op, torch.nn.Sigmoid):