renamed the low-complexity pixel2pixel models with the suffix "lite"
author Manu Mathew <a0393608@ti.com>
Fri, 21 Feb 2020 06:13:05 +0000 (11:43 +0530)
committer Manu Mathew <a0393608@ti.com>
Fri, 21 Feb 2020 06:24:18 +0000 (11:54 +0530)
24 files changed:
docs/Semantic_Segmentation.md
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/__init__.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpnlite_pixel2pixel.py [moved from modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py with 83% similarity]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unetlite_pixel2pixel.py [moved from modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unet_pixel2pixel.py with 83% similarity]
modules/pytorch_jacinto_ai/xnn/layers/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/deconv_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py [deleted file]
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
modules/pytorch_jacinto_ai/xnn/utils/count_flops.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
modules/pytorch_jacinto_ai/xnn/utils/weights_utils.py
run_depth.sh
run_segmentation.sh
scripts/train_depth_main.py
scripts/train_segmentation_main.py

diff --git a/docs/Semantic_Segmentation.md b/docs/Semantic_Segmentation.md
index 5d9d048c6f1bf8f6e93b74f9ecd554f314913d50..fa441cbe42f2320f54d74c21215f3a02eacaff58 100644 (file)
@@ -9,9 +9,9 @@ A set of common example model configurations are defined for all pixel to pixel
 
 Whether to use multiple inputs or how many decoders to use are fully configurable. This framework is also flexible enough to add different model architectures or backbone networks. Some of the model configurations currently available are:<br>
 * **deeplabv3lite_mobilenetv2_tv**: (default) This model is mostly similar to the DeepLabV3+ model [[6]] using MobileNetV2 backbone. The difference with DeepLabV3+ is that we removed the convolutions after the shortcut and kept one set of depthwise separable convolutions to generate the prediction. The ASPP module that we used is a light-weight variant with depthwise separable convolutions (DWASPP). We found that this reduces complexity without sacrificing accuracy. Due to this we call this model DeepLabV3+(Lite) or simply DeepLabV3Lite. (Note: The suffix "_tv" is used to indicate that our backbone model is from torchvision)<br>
-* **fpn_pixel2pixel_aspp_mobilenetv2_tv**: This is similar to Feature Pyramid Network [[3]], but adapted for pixel2pixel tasks. We stop the decoder at a stride of 4 and then upsample to the final resolution from there. We also use DWASPP module to improve the receptive field. We call this model FPNPixel2Pixel. 
-* **fpn_pixel2pixel_aspp_mobilenetv2_tv_fd**: This is also FPN, but with a larger encoder stride(64). This is a low complexity model (using Fast Downsampling Strategy [8]) that can be used with higher resolutions.
-* **fpn_pixel2pixel_aspp_resnet50**: Feature Pyramid Network (FPN) based pixel2pixel using ResNet50 backbone with DWASPP.
+* **fpnlite_pixel2pixel_aspp_mobilenetv2_tv**: This is similar to Feature Pyramid Network [[3]], but adapted for pixel2pixel tasks. We stop the decoder at a stride of 4 and then upsample to the final resolution from there. We also use the DWASPP module to improve the receptive field. We call this model FPNLitePixel2Pixel.
+* **fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd**: This is also FPN, but with a larger encoder stride (64). This is a low complexity model (using the Fast Downsampling Strategy [8]) that can be used with higher resolutions.
+* **fpnlite_pixel2pixel_aspp_resnet50**: Feature Pyramid Network (FPN) based pixel2pixel using ResNet50 backbone with DWASPP.
 
 ## Datasets: Cityscapes Dataset 
 
@@ -57,7 +57,7 @@ Whether to use multiple inputs or how many decoders to use are fully configurabl
 
 * Train FPNPixel2Pixel model at 1536x768 resolution (use 1024x512 crop to reduce memory usage):<br>
     ```
-    python ./scripts/train_segmentation_main.py --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --gpus 0 1
+    python ./scripts/train_segmentation_main.py --model_name fpnlite_pixel2pixel_aspp_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --gpus 0 1
     ```
  
 * **VOC Segmentation Training** can be done as follows:<br>
@@ -81,20 +81,20 @@ Inference can be done as follows (fill in the path to the pretrained model):<br>
 
 ### Cityscapes Segmentation
 
-|Dataset    |Mode Architecture         |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                  |
-|---------  |----------                |-----------    |-------------- |-----------|--------             |----------|----------------------------------------  |
-|Cityscapes |FPNPixel2Pixel with DWASPP|FD-MobileNetV2 |64             |768x384    |0.99                 |62.43     |fpn_pixel2pixel_aspp_mobilenetv2_tv_fd    |
-|Cityscapes |UNet with DWASPP          |MobileNetV2    |32             |768x384    |**2.20**             |**68.94** |**unet_pixel2pixel_aspp_mobilenetv2_tv**  |
-|Cityscapes |DeepLabV3Lite with DWASPP |MobileNetV2    |16             |768x384    |**3.54**             |**69.13** |**deeplabv3lite_mobilenetv2_tv**          |
-|Cityscapes |FPNPixel2Pixel            |MobileNetV2    |32(\*2\*2)     |768x384    |3.66                 |70.30     |fpn_pixel2pixel_mobilenetv2_tv            |
-|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2    |32             |768x384    |3.84                 |70.39     |fpn_pixel2pixel_aspp_mobilenetv2_tv       |
-|Cityscapes |FPNPixel2Pixel            |FD-MobileNetV2 |64(\*2\*2)     |1536x768   |3.85                 |69.82     |fpn_pixel2pixel_mobilenetv2_tv_fd         |
-|Cityscapes |FPNPixel2Pixel with DWASPP|FD-MobileNetV2 |64             |1536x768   |**3.96**             |**71.28** |**fpn_pixel2pixel_aspp_mobilenetv2_tv_fd**|
-|Cityscapes |FPNPixel2Pixel with DWASPP|FD-MobileNetV2 |64             |2048x1024  |7.03                 |72.67     |fpn_pixel2pixel_aspp_mobilenetv2_tv_fd    |
-|Cityscapes |DeepLabV3Lite with DWASPP |MobileNetV2    |16             |1536x768   |14.48                |73.59     |deeplabv3lite_mobilenetv2_tv              |
-|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2    |32             |1536x768   |**15.37**            |**74.98** |**fpn_pixel2pixel_aspp_mobilenetv2_tv**   |
-|Cityscapes |FPNPixel2Pixel with DWASPP|FD-ResNet50    |64             |1536x768   |30.91                |-         |fpn_pixel2pixel_aspp_resnet50_fd          |
-|Cityscapes |FPNPixel2Pixel with DWASPP|ResNet50       |32             |1536x768   |114.42               |-         |fpn_pixel2pixel_aspp_resnet50             |
+|Dataset    |Model Architecture            |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                      |
+|---------  |----------                    |-----------    |-------------- |-----------|--------             |----------|----------------------------------------      |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|FD-MobileNetV2 |64             |768x384    |0.99                 |62.43     |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd    |
+|Cityscapes |UNetLite with DWASPP          |MobileNetV2    |32             |768x384    |**2.20**             |**68.94** |**unetlite_pixel2pixel_aspp_mobilenetv2_tv**  |
+|Cityscapes |DeepLabV3Lite with DWASPP     |MobileNetV2    |16             |768x384    |**3.54**             |**69.13** |**deeplabv3lite_mobilenetv2_tv**              |
+|Cityscapes |FPNLitePixel2Pixel            |MobileNetV2    |32(\*2\*2)     |768x384    |3.66                 |70.30     |fpnlite_pixel2pixel_mobilenetv2_tv            |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|MobileNetV2    |32             |768x384    |3.84                 |70.39     |fpnlite_pixel2pixel_aspp_mobilenetv2_tv       |
+|Cityscapes |FPNLitePixel2Pixel            |FD-MobileNetV2 |64(\*2\*2)     |1536x768   |3.85                 |69.82     |fpnlite_pixel2pixel_mobilenetv2_tv_fd         |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|FD-MobileNetV2 |64             |1536x768   |**3.96**             |**71.28** |**fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd**|
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|FD-MobileNetV2 |64             |2048x1024  |7.03                 |72.67     |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd    |
+|Cityscapes |DeepLabV3Lite with DWASPP     |MobileNetV2    |16             |1536x768   |14.48                |73.59     |deeplabv3lite_mobilenetv2_tv                  |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|MobileNetV2    |32             |1536x768   |**15.37**            |**74.98** |**fpnlite_pixel2pixel_aspp_mobilenetv2_tv**   |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|FD-ResNet50    |64             |1536x768   |30.91                |-         |fpnlite_pixel2pixel_aspp_resnet50_fd          |
+|Cityscapes |FPNLitePixel2Pixel with DWASPP|ResNet50       |32             |1536x768   |114.42               |-         |fpnlite_pixel2pixel_aspp_resnet50             |
 
 |Dataset    |Model Architecture        |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                  |
 |---------  |----------                |-----------    |-------------- |-----------|--------             |----------|----------------------------------------  |
@@ -103,7 +103,8 @@ Inference can be done as follows (fill in the path to the pretrained model):<br>
 |Cityscapes |DeepLabV3Plus[[6,7]]      |MobileNetV2    |16             |           |21.27                |70.71     |-                                         |
 |Cityscapes |DeepLabV3Plus[[6,7]]      |Xception65     |16             |           |418.64               |78.79     |-                                         |
 
-Notes: 
+Notes:
+- The suffix **'Lite'** in the model names indicates complexity optimizations in the Decoder part of the model - especially the use of DepthWise Separable Convolutions instead of regular convolutions (a small sketch of such a block follows these notes).
 - (\*2\*2) in the above table represents two additional Depthwise Separable Convolutions with strides (at the end of the backbone encoder). 
 - FD-MobileNetV2 Backbone uses a stride of 64 (this is used in some rows of the above table) and is achieved by the Fast Downsampling Strategy [8].
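
For illustration, the following is a minimal, self-contained sketch of the kind of depthwise separable prediction block the 'Lite' decoders use in place of a regular 3x3 convolution. It is not the repo's `xnn.layers.ConvDWSepNormAct2d` (whose interface differs); the class name and channel counts below are made up for the example.

```
import torch

# Hypothetical depthwise separable 3x3 "prediction" block (sketch only, not repo code).
class DepthwiseSeparablePred(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # depthwise 3x3: one filter per input channel
        self.depthwise = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                         padding=1, groups=in_channels, bias=False)
        self.bn = torch.nn.BatchNorm2d(in_channels)
        # pointwise 1x1: mixes channels and produces the per-pixel predictions
        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.bn(self.depthwise(x)))

# a regular 3x3 conv costs in*out*9 MACs per pixel; the separable version costs in*9 + in*out
pred = DepthwiseSeparablePred(256, 19)      # e.g. 19 Cityscapes classes
y = pred(torch.randn(1, 256, 96, 192))      # stride-4 feature map of a 768x384 input
print(y.shape)                              # torch.Size([1, 19, 96, 192])
```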
 
diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py
index 4c62bac0d2aadfe753ff08ebe901ee9154f07da5..e5dadf39f3e017bf08d78c68c279816ae183a276 100644 (file)
@@ -31,6 +31,7 @@ def get_config():
     args.dataset_config = xnn.utils.ConfigNode()
     args.model_config.num_tiles_x = int(1)
     args.model_config.num_tiles_y = int(1)
+    args.model_config.en_make_divisible_by8 = True
 
     args.model_config.input_channels = 3                # num input channels
 
@@ -103,7 +104,7 @@ def get_config():
     args.per_channel_q = False                          # apply a separate quantization factor for each channel in depthwise layers or not
 
     args.freeze_bn = False                              # freeze the statistics of bn
-
+    args.save_mod_files = False                         # saves files modified after the last commit; also stores the commit id.
     return args
 
 
@@ -138,6 +139,19 @@ def main(args):
     if not os.path.exists(save_path):
         os.makedirs(save_path)
 
+    if args.save_mod_files:
+        # store all files modified since the last commit.
+        mod_files_path = save_path+'/mod_files'
+        os.makedirs(mod_files_path)
+        
+        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+        # store the last commit id.
+        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
+        print("cmd:", cmd)    
+        os.system(cmd)
     #################################################
     if args.logger is None:
         log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
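
As a rough illustration of what the `save_mod_files` step above does, here is a standalone equivalent using `subprocess` and `shutil` instead of `os.system`. The helper name `save_modified_files` and the library choices are assumptions for this sketch, not code from the repo.

```
import os
import shutil
import subprocess

# Hypothetical standalone equivalent of the save_mod_files block above (a sketch, not repo code):
# copy every file modified since the last commit into <save_path>/mod_files and record the
# last commit id next to the copies.
def save_modified_files(save_path):
    mod_files_path = os.path.join(save_path, 'mod_files')
    os.makedirs(mod_files_path, exist_ok=True)
    modified = subprocess.check_output(['git', 'ls-files', '--modified'], text=True).split()
    for fname in modified:
        shutil.copy(fname, mod_files_path)
    with open(os.path.join(mod_files_path, 'commit_id.txt'), 'w') as fp:
        fp.write(subprocess.check_output(['git', 'log', '-n', '1'], text=True))
```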
diff --git a/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py b/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py
index 0f13921bddcb43226dd9b1f111ba1c2c18434db2..c9ee147a247ca731d424bc8053f784f4ec0637ad 100644 (file)
@@ -21,6 +21,7 @@ def get_config():
     model_config.linear_dw = False
     model_config.layer_setting = None
     model_config.classifier_type = torch.nn.Linear
+    model_config.en_make_divisible_by8 = True
     return model_config
 
 model_urls = {
@@ -66,13 +67,14 @@ class MobileNetV1Base(torch.nn.Module):
         kernel_size = self.model_config.kernel_size
 
         # building first layer
-        output_channels = make_divisible_by8(self.model_config.layer_setting[0][1] * width_mult)
+        output_channels = int(self.model_config.layer_setting[0][1] * width_mult)
+        output_channels = make_divisible_by8(output_channels) if model_config.en_make_divisible_by8 else output_channels
         features = [xnn.layers.ConvNormAct2d(3, output_channels, kernel_size=kernel_size, stride=s0, activation=activation)]
         channels = output_channels
 
         # building inverted residual blocks
         for t, c, n, s in self.model_config.layer_setting[1:]:
-            output_channels = make_divisible_by8(c * width_mult)
+            output_channels = make_divisible_by8(c * width_mult) if model_config.en_make_divisible_by8 else int(c * width_mult)
             for i in range(n):
                 stride = s if i == 0 else 1
                 block = BlockBuilder(channels, output_channels, stride=stride, kernel_size=kernel_size, activation=(activation,activation))
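
For context on the `en_make_divisible_by8` flag introduced above, the sketch below shows the usual MobileNet-style rounding that such a helper performs; the actual `make_divisible_by8` in this repo's xnn utilities may differ in detail, so treat this as an assumption.

```
# Hedged sketch of a make_divisible_by8-style helper (the repo's own implementation may differ).
def make_divisible_by8(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # do not let the rounding shrink the channel count by more than ~10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# e.g. a 24-channel layer with width_mult=0.5 gives 12 channels, rounded up to 16
print(make_divisible_by8(24 * 0.5))   # 16
```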
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/__init__.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/__init__.py
index a0aa684634795373bbe2d23d41b2e0c247a0cb80..a99f34275c8b44cc6fdc400fddbddc00a7c3b93b 100644 (file)
@@ -1,6 +1,6 @@
 from .deeplabv3lite import *
-from .fpn_pixel2pixel import *
-from .unet_pixel2pixel import *
+from .fpnlite_pixel2pixel import *
+from .unetlite_pixel2pixel import *
 
 try: from .deeplabv3lite_internal import *
 except: pass
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
index 4bf0c0b23b6e76908e77bd51c7c5cc149218334a..e6942fd481d1d843d5f4c11305013849b2b2642a 100644 (file)
@@ -48,17 +48,11 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
 
         self.cat = xnn.layers.CatBlock()
 
-        # prediction
+        # add prediction & upsample modules
         if self.model_config.final_prediction:
-            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
-            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) if (model_config.output_range is not None) else False
-            self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3, normalization=((not model_config.linear_dw),False), activation=(False,final_activation), groups=1)
-            if self.model_config.final_upsample:
-                upstride2 = model_config.shortcut_strides[0]
-                # use UpsampleGenericTo() instead of UpsampleTo() to break down large upsampling factors to multiples of 4 and 2 -
-                # useful if upsampling factors other than 4 and 2 are not supported.
-                self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)
-            #
+            add_lite_prediction_modules(self, model_config, merged_channels, module_names=('pred','upsample2'))
+        #
+
 
     # the upsampling is using functional form to support size based upsampling for odd sizes
     # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
@@ -91,8 +85,6 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
             x = self.pred(x)
 
             if self.model_config.final_upsample:
-                # final prediction is the upsampled one
-                #scale_factor = (in_shape[2]/x.shape[2], in_shape[3]/x.shape[3])
                 x = self.upsample2((x, x_input))
 
             if (not self.training) and (self.model_config.output_type == 'segmentation'):
similarity index 83%
rename from modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py
rename to modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpnlite_pixel2pixel.py
index 8bed13818f8c534a06f3ac3e96345f699e840896..d3aeb263d04b93c9fec4644c51a40377c7fa8c86 100644 (file)
@@ -6,16 +6,16 @@ from .pixel2pixelnet import *
 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
 
 
-__all__ = ['FPNPixel2PixelASPP', 'FPNPixel2PixelDecoder',
-           'fpn_pixel2pixel_aspp_mobilenetv2_tv', 'fpn_pixel2pixel_aspp_mobilenetv2_tv_fd', 'fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd',
+__all__ = ['FPNLitePixel2PixelASPP', 'FPNLitePixel2PixelDecoder',
+           'fpnlite_pixel2pixel_aspp_mobilenetv2_tv', 'fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd', 'fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd',
            # no aspp models
-           'fpn_pixel2pixel_mobilenetv2_tv', 'fpn_pixel2pixel_mobilenetv2_tv_fd',
+           'fpnlite_pixel2pixel_mobilenetv2_tv', 'fpnlite_pixel2pixel_mobilenetv2_tv_fd',
            # resnet models
-           'fpn_pixel2pixel_aspp_resnet50', 'fpn_pixel2pixel_aspp_resnet50_fd',
+           'fpnlite_pixel2pixel_aspp_resnet50', 'fpnlite_pixel2pixel_aspp_resnet50_fd',
            ]
 
 # config settings for mobilenetv2 backbone
-def get_config_fpnp2p_mnv2():
+def get_config_fpnlitep2p_mnv2():
     model_config = xnn.utils.ConfigNode()
     model_config.num_classes = None
     model_config.num_decoders = None
@@ -57,7 +57,7 @@ def get_config_fpnp2p_mnv2():
 
 
 ###########################################
-class FPNPyramid(torch.nn.Module):
+class FPNLitePyramid(torch.nn.Module):
     def __init__(self, current_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=False, all_outputs=False):
         super().__init__()
         self.inloop_fpn = inloop_fpn
@@ -120,13 +120,13 @@ class FPNPyramid(torch.nn.Module):
         return outputs[::-1]
 
 
-class InLoopFPNPyramid(FPNPyramid):
+class InLoopFPNLitePyramid(FPNLitePyramid):
     def __init__(self, input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=True, all_outputs=False):
         super().__init__(input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=inloop_fpn, all_outputs=all_outputs)
 
 
 ###########################################
-class FPNPixel2PixelDecoder(torch.nn.Module):
+class FPNLitePixel2PixelDecoder(torch.nn.Module):
     def __init__(self, model_config):
         super().__init__()
         self.model_config = model_config
@@ -154,21 +154,16 @@ class FPNPixel2PixelDecoder(torch.nn.Module):
 
         shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
         shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
-        FPNType = InLoopFPNPyramid if model_config.inloop_fpn else FPNPyramid
+        FPNType = InLoopFPNLitePyramid if model_config.inloop_fpn else FPNLitePyramid
         self.fpn = FPNType(current_channels, decoder_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                            self.model_config.interpolation_type, self.model_config.interpolation_mode)
 
-        # prediction
+        # add prediction & upsample modules
         if self.model_config.final_prediction:
-            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) if (model_config.output_range is not None) else False
-            self.pred = xnn.layers.ConvDWSepNormAct2d(current_channels, self.model_config.output_channels, kernel_size=3, normalization=(True,False), activation=(False,final_activation))
-
-            if self.model_config.final_upsample:
-                upstride_final = self.model_config.shortcut_strides[0]
-                self.upsample = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride_final, model_config.interpolation_type, model_config.interpolation_mode)
-            #
+            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
         #
 
+
     def forward(self, x_input, x, x_list):
         assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
         assert x is x_list[-1], 'the features must be the last one in x_list'
@@ -205,19 +200,19 @@ class FPNPixel2PixelDecoder(torch.nn.Module):
 
 
 ###########################################
-class FPNPixel2PixelASPP(Pixel2PixelNet):
+class FPNLitePixel2PixelASPP(Pixel2PixelNet):
     def __init__(self, base_model, model_config):
-        super().__init__(base_model, FPNPixel2PixelDecoder, model_config)
+        super().__init__(base_model, FPNLitePixel2PixelDecoder, model_config)
 
 
 ###########################################
-def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
+def fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
     # encoder setup
     model_config_e = model_config.clone()
     base_model = MobileNetV2TVMI4(model_config_e)
     # decoder setup
-    model = FPNPixel2PixelASPP(base_model, model_config)
+    model = FPNLitePixel2PixelASPP(base_model, model_config)
 
     num_inputs = len(model_config.input_channels)
     num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
@@ -239,43 +234,43 @@ def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
 
 
 # fast down sampling model (encoder stride 64 model)
-def fpn_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
+def fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
     model_config.fastdown = True
     model_config.strides = (2,2,2,2,2)
     model_config.shortcut_strides = (8,16,32,64)
     model_config.shortcut_channels = (24,32,96,320)
     model_config.decoder_chan = 256
     model_config.aspp_chan = 256
-    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
 
 
 # fast down sampling model (encoder stride 64 model) with fpn decoder channels 128
 def fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
+    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
     model_config.fastdown = True
     model_config.strides = (2,2,2,2,2)
     model_config.shortcut_strides = (4,8,16,32,64)
     model_config.shortcut_channels = (16,24,32,96,320)
     model_config.decoder_chan = 128
     model_config.aspp_chan = 128
-    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
 
 
 ##################
 # similar to the original fpn model with extra convolutions with strides (no aspp)
-def fpn_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
+def fpnlite_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
     model_config.use_aspp = False
     model_config.use_extra_strides = True
     model_config.shortcut_strides = (4, 8, 16, 32, 64, 128)
     model_config.shortcut_channels = (24, 32, 96, 320, 320, 256)
-    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
 
 
 # similar to the original fpn model with extra convolutions with strides (no aspp) - fast down sampling model (encoder stride 64 model)
-def fpn_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
+def fpnlite_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
     model_config.use_aspp = False
     model_config.use_extra_strides = True
     model_config.fastdown = True
@@ -284,24 +279,24 @@ def fpn_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
     model_config.shortcut_channels = (24, 32, 96, 320, 320, 256)
     model_config.decoder_chan = 256
     model_config.aspp_chan = 256
-    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
 
 
 ###########################################
-def get_config_fpnp2p_resnet50():
+def get_config_fpnlitep2p_resnet50():
     # only the delta compared to the one defined for mobilenetv2
-    model_config = get_config_fpnp2p_mnv2()
+    model_config = get_config_fpnlitep2p_mnv2()
     model_config.shortcut_channels = (256,512,1024,2048)
     return model_config
 
 
-def fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_resnet50().merge_from(model_config)
+def fpnlite_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_resnet50().merge_from(model_config)
     # encoder setup
     model_config_e = model_config.clone()
     base_model = ResNet50MI4(model_config_e)
     # decoder setup
-    model = FPNPixel2PixelASPP(base_model, model_config)
+    model = FPNLitePixel2PixelASPP(base_model, model_config)
 
     # the pretrained model provided by torchvision and what is defined here differs slightly
     # note: that this change_names_dict  will take effect only if the direct load fails
@@ -333,12 +328,12 @@ def fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
     return model, change_names_dict
 
 
-def fpn_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
-    model_config = get_config_fpnp2p_resnet50().merge_from(model_config)
+def fpnlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
+    model_config = get_config_fpnlitep2p_resnet50().merge_from(model_config)
     model_config.fastdown = True
     model_config.strides = (2,2,2,2,2)
     model_config.shortcut_strides = (8,16,32,64) #(4,8,16,32,64)
     model_config.shortcut_channels = (256,512,1024,2048) #(64,256,512,1024,2048)
     model_config.decoder_chan = 256 #128
     model_config.aspp_chan = 256 #128
-    return fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+    return fpnlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet.py
index bcfa506221ef7218ad42ef57b9c730ca5e1b1e84..9b8780271c8f6698a00b1213a651c47c3388fb68 100644 (file)
@@ -1,7 +1,6 @@
 import torch
 from .... import xnn
-import torch.nn.functional as F
-import copy
+from .pixel2pixelnet_utils import *
 
 
 ###########################################
@@ -93,4 +92,3 @@ class Pixel2PixelNet(torch.nn.Module):
         x_out = xnn.layers.split_output_channels(x_out[0], self.output_channels) if (self.num_decoders <= 1 and self.split_outputs) else x_out
         return x_out
 
-
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py
new file mode 100644 (file)
index 0000000..5fab7c0
--- /dev/null
@@ -0,0 +1,41 @@
+from .... import xnn
+
+# add prediction and final upsample blocks to pixel2pixel models
+def add_lite_prediction_modules(self, model_config, current_channels, module_names):
+    # prediction and upsample
+    if self.model_config.final_prediction:
+        # use UpsampleGenericTo() instead of UpsampleTo(), to break down large upsampling factors to multiples of 4 and 2
+        # useful if upsampling factors other than 4 and 2 are not supported.
+        UpsampleClass = xnn.layers.UpsampleTo
+
+        # can control the range of final output with output_range
+        final_activation = xnn.layers.get_fixed_pact2(output_range=model_config.output_range) if (model_config.output_range is not None) else False
+        upstride2 = model_config.shortcut_strides[0]
+
+        if self.model_config.final_upsample and self.model_config.interpolation_type in ('deconv','upsample_conv','subpixel_conv'):
+            # some of the upsample blocks have conv or deconv layers
+            # since these blocks can be used to do prediction as well, perform both prediction and upsample together in these cases.
+            # otherwise, using these conv/deconv layers after the prediction (with very few channels) will make them difficult to train.
+            pred = xnn.layers.BypassBlock()
+            setattr(self, module_names[0], pred)
+
+            upsample2 = UpsampleClass(current_channels, model_config.output_channels, upstride2,
+                                    model_config.interpolation_type, model_config.interpolation_mode,
+                                    is_final_layer=True, final_activation=final_activation)
+            setattr(self, module_names[1], upsample2)
+        else:
+            # prediction followed by conventional interpolation
+            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
+            pred = ConvXWSepBlock(current_channels, model_config.output_channels, kernel_size=3,
+                                       normalization=((not model_config.linear_dw),False),
+                                       activation=(False,final_activation), groups=1)
+            setattr(self, module_names[0], pred)
+
+            if self.model_config.final_upsample:
+                upsample2 = UpsampleClass(model_config.output_channels, model_config.output_channels, upstride2,
+                                    model_config.interpolation_type, model_config.interpolation_mode,
+                                    is_final_layer=True, final_activation=final_activation)
+                setattr(self, module_names[1], upsample2)
+            #
+        #
+    #
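
A toy illustration of the attachment pattern used by `add_lite_prediction_modules`: the helper builds the head modules and registers them on the decoder under caller-chosen attribute names via `setattr`, which is why the DeepLab decoder can request `('pred','upsample2')` while the FPN/UNet decoders request `('pred','upsample')`. All names and channel counts in the sketch are made up and it does not use xnn.

```
import torch

# Hypothetical, simplified version of the setattr-based attachment pattern (not repo code).
def add_toy_prediction_modules(owner, in_channels, out_channels, module_names=('pred', 'upsample')):
    pred = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
    setattr(owner, module_names[0], pred)        # registered as owner.pred
    upsample = torch.nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
    setattr(owner, module_names[1], upsample)    # registered as owner.upsample

class ToyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        add_toy_prediction_modules(self, 256, 19)

decoder = ToyDecoder()
print(list(dict(decoder.named_modules()).keys()))   # ['', 'pred', 'upsample'] - both registered
```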
similarity index 83%
rename from modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unet_pixel2pixel.py
rename to modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unetlite_pixel2pixel.py
index 94b8204c1af96e09c65597adf1f535fd07aaf6d1..b00e8b0beebe107d5177b96786bd9f2b72ec2156 100644 (file)
@@ -6,13 +6,13 @@ from .pixel2pixelnet import *
 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
 
 
-__all__ = ['UNetPixel2PixelASPP', 'UNetPixel2PixelDecoder',
-           'unet_pixel2pixel_aspp_mobilenetv2_tv', 'unet_pixel2pixel_aspp_mobilenetv2_tv_fd',
-           'unet_pixel2pixel_aspp_resnet50', 'unet_pixel2pixel_aspp_resnet50_fd',
+__all__ = ['UNetLitePixel2PixelASPP', 'UNetLitePixel2PixelDecoder',
+           'unetlite_pixel2pixel_aspp_mobilenetv2_tv', 'unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd',
+           'unetlite_pixel2pixel_aspp_resnet50', 'unetlite_pixel2pixel_aspp_resnet50_fd',
            ]
 
 # config settings for mobilenetv2 backbone
-def get_config_unetp2p_mnv2():
+def get_config_unetlitep2p_mnv2():
     model_config = xnn.utils.ConfigNode()
     model_config.num_classes = None
     model_config.num_decoders = None
@@ -52,7 +52,7 @@ def get_config_unetp2p_mnv2():
 
 
 ###########################################
-class UNetPyramid(torch.nn.Module):
+class UNetLitePyramid(torch.nn.Module):
     def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode):
         super().__init__()
         self.shortcut_strides = shortcut_strides
@@ -103,7 +103,7 @@ class UNetPyramid(torch.nn.Module):
 
 
 ###########################################
-class UNetPixel2PixelDecoder(torch.nn.Module):
+class UNetLitePixel2PixelDecoder(torch.nn.Module):
     def __init__(self, model_config):
         super().__init__()
         self.model_config = model_config
@@ -132,21 +132,16 @@ class UNetPixel2PixelDecoder(torch.nn.Module):
         minimum_channels = max(self.model_config.output_channels*2, 32)
         shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
         shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
-        self.unet = UNetPyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
+        self.unet = UNetLitePyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                            self.model_config.interpolation_type, self.model_config.interpolation_mode)
         current_channels = max(minimum_channels, shortcut_channels[-1])
 
-        # prediction
+        # add prediction & upsample modules
         if self.model_config.final_prediction:
-            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) if (model_config.output_range is not None) else False
-            self.pred = xnn.layers.ConvDWSepNormAct2d(current_channels, self.model_config.output_channels, kernel_size=3, normalization=(True,False), activation=(False,final_activation))
-
-            if self.model_config.final_upsample:
-                upstride_final = self.model_config.shortcut_strides[0]
-                self.upsample = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride_final, model_config.interpolation_type, model_config.interpolation_mode)
-            #
+            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
         #
 
+
     def forward(self, x_input, x, x_list):
         assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
         assert x is x_list[-1], 'the features must be the last one in x_list'
@@ -183,19 +178,19 @@ class UNetPixel2PixelDecoder(torch.nn.Module):
 
 
 ###########################################
-class UNetPixel2PixelASPP(Pixel2PixelNet):
+class UNetLitePixel2PixelASPP(Pixel2PixelNet):
     def __init__(self, base_model, model_config):
-        super().__init__(base_model, UNetPixel2PixelDecoder, model_config)
+        super().__init__(base_model, UNetLitePixel2PixelDecoder, model_config)
 
 
 ###########################################
-def unet_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
-    model_config = get_config_unetp2p_mnv2().merge_from(model_config)
+def unetlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
+    model_config = get_config_unetlitep2p_mnv2().merge_from(model_config)
     # encoder setup
     model_config_e = model_config.clone()
     base_model = MobileNetV2TVMI4(model_config_e)
     # decoder setup
-    model = UNetPixel2PixelASPP(base_model, model_config)
+    model = UNetLitePixel2PixelASPP(base_model, model_config)
 
     num_inputs = len(model_config.input_channels)
     num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
@@ -217,33 +212,33 @@ def unet_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
 
 
 # fast down sampling model (encoder stride 64 model)
-def unet_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
-    model_config = get_config_unetp2p_mnv2().merge_from(model_config)
+def unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
+    model_config = get_config_unetlitep2p_mnv2().merge_from(model_config)
     model_config.fastdown = True
     model_config.strides = (2,2,2,2,2)
     model_config.shortcut_strides = (4,8,16,32,64)
     model_config.shortcut_channels = (16,24,32,96,320)
     model_config.decoder_chan = 256
     model_config.aspp_chan = 256
-    return unet_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+    return unetlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
 
 
 ###########################################
-def get_config_unetp2p_resnet50():
+def get_config_unetlitep2p_resnet50():
     # only the delta compared to the one defined for mobilenetv2
-    model_config = get_config_unetp2p_mnv2()
+    model_config = get_config_unetlitep2p_mnv2()
     model_config.shortcut_strides = (2,4,8,16,32)
     model_config.shortcut_channels = (64,256,512,1024,2048)
     return model_config
 
 
-def unet_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
-    model_config = get_config_unetp2p_resnet50().merge_from(model_config)
+def unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
+    model_config = get_config_unetlitep2p_resnet50().merge_from(model_config)
     # encoder setup
     model_config_e = model_config.clone()
     base_model = ResNet50MI4(model_config_e)
     # decoder setup
-    model = UNetPixel2PixelASPP(base_model, model_config)
+    model = UNetLitePixel2PixelASPP(base_model, model_config)
 
     # the pretrained model provided by torchvision and what is defined here differs slightly
     # note: that this change_names_dict  will take effect only if the direct load fails
@@ -275,12 +270,12 @@ def unet_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
     return model, change_names_dict
 
 
-def unet_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
-    model_config = get_config_unetp2p_resnet50().merge_from(model_config)
+def unetlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
+    model_config = get_config_unetlitep2p_resnet50().merge_from(model_config)
     model_config.fastdown = True
     model_config.strides = (2,2,2,2,2)
     model_config.shortcut_strides = (2,4,8,16,32,64) #(4,8,16,32,64)
     model_config.shortcut_channels = (64,64,256,512,1024,2048) #(64,256,512,1024,2048)
     model_config.decoder_chan = 256 #128
     model_config.aspp_chan = 256 #128
-    return unet_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+    return unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/__init__.py b/modules/pytorch_jacinto_ai/xnn/layers/__init__.py
index b4205f6a2bd15b6062e7e1751d8469074fcf98f7..bf977cf99ec4ffc3491528c3ba7c6da0bab225e8 100644 (file)
@@ -1,17 +1,16 @@
+from .model_utils import *
+from .import functional
 from .normalization import *
 from .activation import *
-from .common_blocks import *
+from .layer_config import *
 
+from .common_blocks import *
 from .conv_blocks import *
-from .upsample_blocks import *
-from .multi_task import *
+from .deconv_blocks import *
+from .resize_blocks import *
 
+from .multi_task import *
 from .rf_blocks import *
-from .import functional
-
-from .layer_config import *
-
-from .model_utils import *
 
 # optional/experimental
 try:
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
index 0c8448d67712e8e2b582f62ddb4fa6b85859bd27..8f9c6577943fb879091b8251a5648deb95147c95 100644 (file)
@@ -164,22 +164,6 @@ class ParallelBlock(torch.nn.Module):
         return x
 
 
-###############################################################
-# Resize to the target size
-class ResizeTo(torch.nn.Module):
-    def __init__(self, mode):
-        super().__init__()
-        self.mode = mode
-
-    def forward(self, input):
-        assert isinstance(input, (list,tuple)), 'must provide two tensors - input and target'
-        x = input[0]
-        xt = input[1]
-        target_size = (int(xt.size(2)), int(xt.size(3)))
-        y = torch.nn.functional.interpolate(x, size=target_size, mode=self.mode)
-        return y
-
-
 
 ###############################################################
 class ShuffleBlock(torch.nn.Module):
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/deconv_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/deconv_blocks.py
new file mode 100644 (file)
index 0000000..f5faa9b
--- /dev/null
@@ -0,0 +1,94 @@
+import torch
+from .conv_blocks import *
+from .layer_config import *
+from .common_blocks import *
+
+###############################################################
+def DeConvLayer2d(in_planes, out_planes, kernel_size, stride=1, groups=1, dilation=1, padding=None, output_padding=None,
+                  bias=False):
+    """convolution with padding"""
+    if (output_padding is None) and (padding is None):
+        if kernel_size % 2 == 0:
+            padding = (kernel_size - stride) // 2
+            output_padding = 0
+        else:
+            padding = (kernel_size - stride + 1) // 2
+            output_padding = 1
+
+    return torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
+                                    padding=padding,
+                                    output_padding=output_padding, bias=bias, groups=groups)
+
+
+def DeConvDWLayer2d(in_planes, out_planes, stride=1, dilation=1, kernel_size=None, padding=None, output_padding=None,
+                    bias=False):
+    """convolution with padding"""
+    assert in_planes == out_planes, 'in DW layer channels must not change'
+    return DeConvLayer2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
+                         groups=in_planes,
+                         padding=padding, output_padding=output_padding, bias=bias)
+
+
+###############################################################
+def DeConvNormAct2d(in_planes, out_planes, kernel_size=None, stride=1, groups=1, dilation=1, padding=None,
+                    output_padding=None, bias=False, \
+                    normalization=DefaultNorm2d, activation=DefaultAct2d):
+    """convolution with padding, BN, ReLU"""
+    if (output_padding is None) and (padding is None):
+        if kernel_size % 2 == 0:
+            padding = (kernel_size - stride) // 2
+            output_padding = 0
+        else:
+            padding = (kernel_size - stride + 1) // 2
+            output_padding = 1
+
+    if activation is True:
+        activation = DefaultAct2d
+
+    if normalization is True:
+        normalization = DefaultNorm2d
+
+    layers = [torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
+                                       padding=padding,
+                                       output_padding=output_padding, bias=bias, groups=groups)]
+    if normalization:
+        layers.append(normalization(out_planes))
+
+    if activation:
+        layers.append(activation(inplace=True))
+    #
+    layers = torch.nn.Sequential(*layers)
+    return layers
+
+
+def DeConvDWNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, dilation=1, padding=None, output_padding=None,
+                      bias=False,
+                      normalization=DefaultNorm2d, activation=DefaultAct2d):
+    """convolution with padding, BN, ReLU"""
+    assert in_planes == out_planes, 'in DW layer channels must not change'
+    return DeConvNormAct2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
+                           padding=padding, output_padding=output_padding,
+                           bias=bias, groups=in_planes, normalization=normalization, activation=activation)
+
+
+###########################################################
+def DeConvDWSepNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
+                         first_1x1=False, normalization=(DefaultNorm2d, DefaultNorm2d),
+                         activation=(DefaultAct2d, DefaultAct2d)):
+    if first_1x1:
+        layers = [
+            ConvNormAct2d(in_planes, out_planes, kernel_size=1, groups=groups, bias=bias,
+                          normalization=normalization[0], activation=activation[0]),
+            DeConvDWNormAct2d(out_planes, out_planes, stride=stride, kernel_size=kernel_size, dilation=dilation,
+                              bias=bias,
+                              normalization=normalization[1], activation=activation[1])]
+    else:
+        layers = [DeConvDWNormAct2d(in_planes, in_planes, stride=stride, kernel_size=kernel_size, dilation=dilation,
+                                    bias=bias,
+                                    normalization=normalization[0], activation=activation[0]),
+                  ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
+                                normalization=normalization[1], activation=activation[1])]
+
+    layers = torch.nn.Sequential(*layers)
+    return layers
+
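
A quick sanity check (my own worked example, not code from the repo) of the padding rule used in DeConvLayer2d / DeConvNormAct2d above: for an even kernel, padding=(k-stride)//2 with output_padding=0, and for an odd kernel, padding=(k-stride+1)//2 with output_padding=1, both give an exact stride-fold upsampling, since the transposed-conv output size is (in-1)*stride - 2*padding + dilation*(k-1) + output_padding + 1.

```
import torch

x = torch.randn(1, 8, 24, 48)

# even kernel (e.g. kernel_size = 2*stride, as the deconv upsample path uses): padding=(k-s)//2, output_padding=0
even = torch.nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=(4 - 2) // 2,
                                output_padding=0, groups=8)
# odd kernel: padding=(k-s+1)//2, output_padding=1
odd = torch.nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=(3 - 2 + 1) // 2,
                               output_padding=1, groups=8)

print(even(x).shape)   # torch.Size([1, 8, 48, 96]) - exactly 2x
print(odd(x).shape)    # torch.Size([1, 8, 48, 96]) - exactly 2x
```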
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py
new file mode 100644 (file)
index 0000000..57fee25
--- /dev/null
@@ -0,0 +1,90 @@
+from .deconv_blocks import *
+
+###############################################################
+class ResizeTo(torch.nn.Module):
+    def __init__(self, mode='bilinear'):
+        '''
+            Resize to the target size
+        '''
+        super().__init__()
+        self.mode = mode
+
+    def forward(self, input):
+        assert isinstance(input, (list,tuple)), 'must provide two tensors - input and target'
+        x = input[0]
+        xt = input[1]
+        target_size = (int(xt.size(2)), int(xt.size(3)))
+        y = torch.nn.functional.interpolate(x, size=target_size, mode=self.mode)
+        return y
+
+
+###############################################################
+def UpsampleTo(input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
+               is_final_layer=False, final_activation=True):
+    '''
+         An upsample block: plain interpolation to the target size, or a deconv / upsample+conv / subpixel-conv block for the given upstride
+         is_final_layer: Final layer in a pixel2pixel network should not typically use BatchNorm for best accuracy
+         final_activation: use to control the range of final layer if required
+     '''
+    if interpolation_type == 'upsample':
+        upsample = ResizeTo(mode=interpolation_mode)
+    else:
+        assert upstride is not None, 'upstride must not be None for this interpolation_type'
+        assert input_channels is not None, 'input_channels must not be None for this interpolation_type'
+        assert output_channels is not None, 'output_channels must not be None for this interpolation_type'
+        final_norm = (False if is_final_layer else True)
+        normalization = (True, final_norm)
+        activation = (False, final_activation)
+        if interpolation_type == 'deconv':
+            upsample = [SplitListTakeFirst(),
+                        DeConvDWSepNormAct2d(input_channels, output_channels, kernel_size=upstride * 2, stride=upstride,
+                                      normalization=normalization, activation=activation)]
+        elif interpolation_type == 'upsample_conv':
+            upsample = [ResizeTo(mode=interpolation_mode),
+                        ConvDWSepNormAct2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1),
+                                      normalization=normalization, activation=activation)]
+        elif interpolation_type == 'subpixel_conv':
+            upsample = [SplitListTakeFirst(),
+                        ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1),
+                                      normalization=normalization, activation=activation),
+                        torch.nn.PixelShuffle(upscale_factor=int(upstride))]
+        else:
+            assert False, f'invalid interpolation_type: {interpolation_type}'
+        #
+        upsample = torch.nn.Sequential(*upsample)
+    #
+
+    return upsample
+
+
+class UpsampleGenericTo(torch.nn.Module):
+    def __init__(self, input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
+                 is_final_layer=False, final_activation=False):
+        '''
+            An upsample module that breaks upsampling factors > 4 into stages of 4 and 2
+            is_final_layer: Final layer in a pixel2pixel network should not typically use BatchNorm
+            final_activation: use to control the range of final layer if required
+        '''
+        super().__init__()
+        self.upsample_list = torch.nn.ModuleList()
+        self.upstride_list = []
+        while upstride >= 2:
+            upstride_layer = 4 if upstride > 4 else upstride
+            upsample = UpsampleTo(input_channels, output_channels, upstride_layer, interpolation_type, interpolation_mode,
+                                  is_final_layer=is_final_layer, final_activation=final_activation)
+            self.upsample_list.append(upsample)
+            self.upstride_list.append(upstride_layer)
+            upstride = upstride//4
+
+    def forward(self, x):
+        assert isinstance(x, (list,tuple)) and len(x)==2, 'input must be a tuple/list of size 2'
+        x, x_target = x
+        xt_shape = x.shape
+        for idx, (upsample, upstride) in enumerate(zip(self.upsample_list,self.upstride_list)):
+            xt_shape = (xt_shape[0], xt_shape[1], xt_shape[2]*upstride, xt_shape[3]*upstride)
+            xt = torch.zeros(xt_shape).to(x.device)
+            x = upsample((x, xt))
+            xt_shape = x.shape
+        #
+        return x
+
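
As a quick check of the stage-splitting loop in UpsampleGenericTo above (my own re-derivation, not repo code), the factor decomposition it produces can be reproduced like this:

```
# Reproduces the while-loop in UpsampleGenericTo.__init__: use a factor of 4 while the remaining
# upstride is > 4, otherwise use the remaining upstride itself, then divide by 4.
def split_upstride(upstride):
    stages = []
    while upstride >= 2:
        stage = 4 if upstride > 4 else upstride
        stages.append(stage)
        upstride //= 4
    return stages

print(split_upstride(64))   # [4, 4, 4]
print(split_upstride(32))   # [4, 4, 2]
print(split_upstride(8))    # [4, 2]
```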
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py
deleted file mode 100644 (file)
index 20bc2e4..0000000
+++ /dev/null
@@ -1,129 +0,0 @@
-from .conv_blocks import *
-from .layer_config import *
-from .common_blocks import *
-
-###############################################################
-def UpsampleTo(input_channels, output_channels, upstride, interpolation_type, interpolation_mode):
-    upsample = []
-    if interpolation_type == 'upsample':
-        upsample = [ResizeTo(mode=interpolation_mode)]
-    elif interpolation_type == 'deconv':
-        upsample = [SplitListTakeFirst(),
-                    DeConvDWLayer2d(input_channels, output_channels, kernel_size=upstride * 2, stride=upstride)]
-    elif interpolation_type == 'upsample_conv':
-        upsample = [ResizeTo(mode=interpolation_mode),
-                    ConvDWLayer2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1))]
-    elif interpolation_type == 'subpixel_conv':
-        upsample = [SplitListTakeFirst(),
-                    ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1), normalization=(True,False), activation=(False,False)),
-                    torch.nn.PixelShuffle(upscale_factor=upstride)]
-    else:
-        assert False, f'invalid interpolation_type: {interpolation_type}'
-    #
-    upsample = torch.nn.Sequential(*upsample)
-    return upsample
-
-
-class UpsampleGenericTo(torch.nn.Module):
-    def __init__(self, input_channels, output_channels, upstride, interpolation_type, interpolation_mode):
-        super().__init__()
-        self.upsample_list = torch.nn.ModuleList()
-        self.upstride_list = []
-        while upstride >= 2:
-            upstride_layer = 4 if upstride > 4 else upstride
-            upsample = UpsampleTo(input_channels, output_channels, upstride_layer, interpolation_type, interpolation_mode)
-            self.upsample_list.append(upsample)
-            self.upstride_list.append(upstride_layer)
-            upstride = upstride//4
-
-    def forward(self, x):
-        assert isinstance(x, (list,tuple)) and len(x)==2, 'input must be a tuple/list of size 2'
-        x, x_target = x
-        xt_shape = x.shape
-        for idx, (upsample, upstride) in enumerate(zip(self.upsample_list,self.upstride_list)):
-            xt_shape = (xt_shape[0], xt_shape[1], xt_shape[2]*upstride, xt_shape[3]*upstride)
-            xt = torch.zeros(xt_shape).to(x.device)
-            x = upsample((x, xt))
-            xt_shape = x.shape
-        #
-        return x
-
-
-############################################################### 
-def DeConvLayer2d(in_planes, out_planes, kernel_size, stride=1, groups=1, dilation=1, padding=None, output_padding=None, bias=False):
-    """convolution with padding"""
-    if (output_padding is None) and (padding is None):
-        if kernel_size % 2 == 0:
-            padding = (kernel_size - stride) // 2
-            output_padding = 0
-        else:
-            padding = (kernel_size - stride + 1) // 2
-            output_padding = 1
-
-    return torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding,
-                                    output_padding=output_padding, bias=bias, groups=groups)
-
-
-def DeConvDWLayer2d(in_planes, out_planes, stride=1, dilation=1, kernel_size=None, padding=None, output_padding=None, bias=False):
-    """convolution with padding"""
-    assert in_planes == out_planes, 'in DW layer channels must not change'
-    return DeConvLayer2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=in_planes,
-                       padding=padding, output_padding=output_padding, bias=bias)
-    
-
-############################################################### 
-def DeConvBNAct(in_planes, out_planes, kernel_size=None, stride=1, groups=1, dilation=1, padding=None, output_padding=None, bias=False, \
-              normalization=DefaultNorm2d, activation=DefaultAct2d):
-    """convolution with padding, BN, ReLU"""
-    if (output_padding is None) and (padding is None):
-        if kernel_size % 2 == 0:
-            padding = (kernel_size - stride) // 2
-            output_padding = 0
-        else:
-            padding = (kernel_size - stride + 1) // 2
-            output_padding = 1
-
-    if activation is True:
-        activation = DefaultAct2d
-
-    if normalization is True:
-        normalization = DefaultNorm2d
-
-    layers = [torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding,
-                                       output_padding=output_padding, bias=bias, groups=groups)]
-    if normalization:
-        layers.append(normalization(out_planes))
-
-    if activation:
-        layers.append(activation(inplace=True))
-    #
-    layers = torch.nn.Sequential(*layers)
-    return layers
-
-    
-def DeConvDWBNAct(in_planes, out_planes, stride=1, kernel_size=None, dilation=1, padding=None, output_padding=None, bias=False,
-                  normalization=DefaultNorm2d, activation=DefaultAct2d):
-    """convolution with padding, BN, ReLU"""
-    assert in_planes == out_planes, 'in DW layer channels must not change'
-    return DeConvBNAct(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding,
-                       bias=bias, groups=in_planes, normalization=normalization, activation=activation)
-
-
-###########################################################
-def DeConvDWSepBNAct(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
-                   first_1x1=False, normalization=(DefaultNorm2d,DefaultNorm2d), activation=(DefaultAct2d,DefaultAct2d)):
-    if first_1x1:
-        layers = [
-            ConvNormAct2d(in_planes, out_planes, kernel_size=1, groups=groups, bias=bias,
-                      normalization=normalization[0], activation=activation[0]),
-            DeConvDWBNAct(out_planes, out_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
-                        normalization=normalization[1], activation=activation[1])]
-    else:
-        layers = [DeConvDWBNAct(in_planes, in_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
-                              normalization=normalization[0], activation=activation[0]),
-                  ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
-                            normalization=normalization[1], activation=activation[1])]
-
-    layers = torch.nn.Sequential(*layers)
-    return layers
-
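
For reference, a minimal sketch (standard PyTorch ConvTranspose2d behaviour; not code from this repository) of why the padding/output_padding rule in the deconvolution helpers above produces an exact stride-times upsampling for both even and odd kernel sizes:

import torch

def deconv_padding(kernel_size, stride):
    # even kernels: symmetric padding, no output_padding needed
    if kernel_size % 2 == 0:
        return (kernel_size - stride) // 2, 0
    # odd kernels: one extra output row/column restores the exact multiple
    return (kernel_size - stride + 1) // 2, 1

x = torch.randn(1, 8, 16, 16)
for k in (3, 4):
    padding, output_padding = deconv_padding(k, stride=2)
    deconv = torch.nn.ConvTranspose2d(8, 8, kernel_size=k, stride=2,
                                      padding=padding, output_padding=output_padding)
    assert deconv(x).shape[-2:] == (32, 32)  # exactly 2x the input resolution
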
index e491a57a6aee6d03be53f170fd04e7d27baf4015..9a0161d11c3f2e5f30bb03a1a7e92f74c32196de 100644 (file)
@@ -310,7 +310,7 @@ class QuantGraphModule(HookedModule):
         #
     #
     def _merge_weight_op(self, module_hash, module, qparams, make_backup):
-        is_conv = isinstance(module,torch.nn.Conv2d)
+        is_conv = utils.is_conv_deconv(module)
 
         # note: we consider merging only if there is a single next node
         next_module = qparams.next_module[0] if len(qparams.next_module) == 1 else None
index fc94ca777ff6572acfb60f6dd33e0829f3105b21..b1a882d07534ff19fcbc606038651f956a97bcbd 100644 (file)
@@ -78,14 +78,17 @@ class QuantTrainModule(QuantBaseModule):
 
         def replace_func(op):
             for name, m in op._modules.items():
-                if isinstance(m,(torch.nn.Conv2d)):
+                if utils.is_conv(m):
                     bias = (m.bias is not None)
                     padding_mode = m.padding_mode
                     new_m = QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=padding_mode)
-                elif isinstance(m,(torch.nn.ConvTranspose2d)):
-                    assert False, 'TODO: handle ConvTranspose2d in replace_quant_modules()'
-                elif isinstance(m,(torch.nn.BatchNorm2d)):
+                elif utils.is_deconv(m):
+                    bias = (m.bias is not None)
+                    padding_mode = m.padding_mode
+                    new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
+                                            padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
+                elif utils.is_bn(m):
                     new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                             track_running_stats=m.track_running_stats)
                 elif isinstance(m, layers.PAct2):
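
The replacement loop above is ordinary module surgery: rebuild each convolution with the same constructor arguments and assign the new module back under the same name. A self-contained sketch of that pattern follows; the FakeQuantConv2d class and the state_dict copy are illustrative assumptions, not code from this commit:

import torch

class FakeQuantConv2d(torch.nn.Conv2d):
    # placeholder for a quantization-aware Conv2d subclass
    pass

def replace_conv2d(op, new_cls=FakeQuantConv2d):
    for name, m in list(op._modules.items()):
        if isinstance(m, torch.nn.Conv2d):
            new_m = new_cls(in_channels=m.in_channels, out_channels=m.out_channels,
                            kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                            dilation=m.dilation, groups=m.groups, bias=(m.bias is not None),
                            padding_mode=m.padding_mode)
            new_m.load_state_dict(m.state_dict())  # keep the trained weights
            setattr(op, name, new_m)
        elif m is not None:
            replace_conv2d(m, new_cls)  # recurse into containers
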
index 5db807215903ac5b14ea21573d6b7af62a4695a3..c91a6aa2efdbce680c9cb1bda73093158cbcc5fe 100644 (file)
@@ -24,8 +24,6 @@ def is_merged_layer(x):
 
 
 ###########################################################
-# convolution with fake quantization of weights
-# ideally quantization has to be done on batch norm folded weights.
 class QuantTrainConv2d(torch.nn.Conv2d):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -56,6 +54,37 @@ class QuantTrainConv2d(torch.nn.Conv2d):
     #
 
 
+###########################################################
+class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_enable = True
+        self.bitwidth_weights = None
+        self.bitwidth_activations = None
+        self.per_channel_q = False
+
+    def forward(self, x):
+        is_merged = is_merged_layer(x)
+        if is_merged:
+           warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
+        #
+
+        y = super().forward(x)
+
+        if not self.quantize_enable:
+            # if quantization is disabled - return
+            return y
+        #
+
+        qparams = get_qparams()
+        qparams.inputs.append(x)
+        qparams.modules.append(self)
+        y.qparams = qparams
+        #
+        return y
+    #
+
+
 ###########################################################
 class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
     def __init__(self, *args, **kwargs):
@@ -71,7 +100,7 @@ class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
             return y
         #
 
-        if is_merged_layer(x) and isinstance(x.qparams.modules[-1], torch.nn.Conv2d):
+        if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
             qparams = get_qparams()
             qparams.inputs = [x.qparams.inputs[0], x]
             qparams.modules = [x.qparams.modules[0], self]
@@ -140,8 +169,11 @@ class QuantTrainPAct2(layers.PAct2):
             conv, weight, bias = None, None, None
         #
 
-        if is_merged and conv is not None:
-            xq = torch.nn.functional.conv2d(xorg, weight, bias, conv.stride, conv.padding, conv.dilation, conv.groups)
+        if is_merged and utils.is_conv(conv):
+            xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
+        elif is_merged and utils.is_deconv(conv):
+            xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
+                                                      dilation=conv.dilation, groups=conv.groups)
         else:
             xq = x
         #
@@ -191,7 +223,7 @@ class QuantTrainPAct2(layers.PAct2):
 
         conv, bn = None, None
         # merge weight and bias (if possible) across layers
-        if len(qparams.modules) == 2 and isinstance(qparams.modules[-2], torch.nn.Conv2d) and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
+        if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
             conv = qparams.modules[-2]
             conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
             #
@@ -206,7 +238,7 @@ class QuantTrainPAct2(layers.PAct2):
             merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale.sign()
             merged_scale_inv = 1.0 / merged_scale_eps
             #
-        elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.Conv2d):
+        elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
             conv = qparams.modules[-1]
             merged_weight = conv.weight
             merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
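
The merging above is the usual conv+batchnorm folding; a compact sketch of the arithmetic for a plain Conv2d (simplified: it ignores the eps/sign clamping and the transposed-convolution weight layout that this file also handles):

import torch

def fold_conv_bn(conv, bn):
    # bn(conv(x)) = gamma * (W*x + b - mean) / sqrt(var + eps) + beta,
    # i.e. an ordinary convolution with per-channel scaled weights and a shifted bias
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    merged_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    merged_bias = (bias - bn.running_mean) * scale + bn.bias
    return merged_weight, merged_bias
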
index bf410acbfb02babcad6cbaa75dff51c883917710..144fd2e547701d97321a65690c76a79bad33c95e 100644 (file)
@@ -1,5 +1,5 @@
 import torch
-
+from . import module_utils
 
 def forward_count_flops(module, inp):
     _add_hook(module, _count_flops_func)
@@ -17,7 +17,7 @@ def _count_flops_func(m, inp, out):
     if isinstance(out, (list,tuple)):
         out = out[0]
     #
-    if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
+    if module_utils.is_conv_deconv(m):
         num_pixels = (out.shape[2] * out.shape[3])
         # Note: channels_in taken from weight shape is already divided by m.groups - no need to divide again
         channels_out, channels_in, kernel_height, kernel_width = m.weight.shape
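
For a plain Conv2d, the counting convention in this hook amounts to output pixels times output channels times per-group input channels times kernel area (multiply-accumulates); a quick sanity check under that assumption:

import torch

conv = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
out = conv(torch.randn(1, 32, 56, 56))

channels_out, channels_in, kernel_height, kernel_width = conv.weight.shape
num_pixels = out.shape[2] * out.shape[3]
macs = num_pixels * channels_out * channels_in * kernel_height * kernel_width
print(macs)  # 56*56 * 64 * 32 * 3*3 = 57,802,752 multiply-accumulates per image
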
index 63137a1f8af0a7fb420059c404693a64cbb9ac95..2e048aad99d578e36237f2de8493463c3a453036 100644 (file)
@@ -13,7 +13,7 @@ def is_activation(module):
                                  layers.PAct2, layers.ReLUN))
     return is_act
 
-def is_pact(module):
+def is_pact2(module):
     is_act = isinstance(module, (layers.PAct2))
     return is_act
 
index 9bf0f6e4fbc85e5f14706976a6bf1c27e6fd8554..4bba7915484d336876b9bb45c896d38156817268 100644 (file)
@@ -1,9 +1,10 @@
 import torch
-from ..layers.functional import ceil2_g
 
 def module_weights_init(module):
     # weight initialization
     for m in module.modules():
+        # note: ConvTranspose2d is not handled here;
+        # PyTorch's default initialization is used for it for now.
         if isinstance(m, torch.nn.Conv2d):
             torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             if m.bias is not None:
index f96d1afefc010b755927e6f788793cffa67937ab..a7c32dea91f717d6dca4cbd1e17eef2d83c9c803 100755 (executable)
@@ -9,7 +9,7 @@
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### KITTI Depth (Manual Download) - Training with ResNet50+FPN
-#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \
+#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name fpnlite_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
 ## =====================================================================================
index c1b6d670cfa7199f8744f545b653e042b3766e7b..d96b700545c6c31364f8c90713a9634d2f9a98e8 100755 (executable)
@@ -5,15 +5,15 @@
 # Models Supported:
 ## =====================================================================================
 # deeplabv3lite_mobilenetv2_tv: deeplabv3lite decoder
-# fpn_pixel2pixel_aspp_mobilenetv2_tv: fpn decoder
-# unet_pixel2pixel_aspp_mobilenetv2_tv: unet decoder
+# fpnlite_pixel2pixel_aspp_mobilenetv2_tv: fpn decoder
+# unetlite_pixel2pixel_aspp_mobilenetv2_tv: unet decoder
 # deeplabv3lite_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
-# fpn_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
-# unet_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
+# fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
+# unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
 #
 # deeplabv3lite_resnet50: uses resnet50 encoder
 # deeplabv3lite_resnet50_p5: low complexity model - uses a resnet50 encoder with half the number of channels (1/4 the complexity). note: this needs specially trained resnet50 pretrained weights
-# fpn_pixel2pixel_aspp_resnet50_fd: low complexity model - with fast downsampling strategy
+# fpnlite_pixel2pixel_aspp_resnet50_fd: low complexity model - with fast downsampling strategy
 
 
 ## =====================================================================================
index d3694479d4b9a3e4aec2ac710fd68e4a44d5a9d4..bf4d74ce4f325cc09d1d99eaeb497e686e07167f 100755 (executable)
@@ -67,7 +67,7 @@ args = train_pixel2pixel.get_config()
 
 ################################
 #Modify arguments
-args.model_name = 'deeplabv3lite_mobilenetv2_tv' #'deeplabv3lite_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_resnet50'
+args.model_name = 'deeplabv3lite_mobilenetv2_tv' #'deeplabv3lite_mobilenetv2_tv' #'fpnlite_pixel2pixel_aspp_mobilenetv2_tv' #'fpnlite_pixel2pixel_aspp_resnet50'
 
 args.dataset_name = 'kitti_depth' #'kitti_depth' #'kitti_depth' #'kitti_depth2'
 
index 24e47e5d7903830d030a3553212b195ec3f6f6b1..ab54b1cd48d4c662760b12e96b81bdab6eb7748c 100755 (executable)
@@ -67,7 +67,7 @@ args = train_pixel2pixel.get_config()
 
 ################################
 #Modify arguments
-args.model_name = 'deeplabv3lite_mobilenetv2_tv' #'deeplabv3lite_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_resnet50'
+args.model_name = 'deeplabv3lite_mobilenetv2_tv' #'deeplabv3lite_mobilenetv2_tv' #'fpnlite_pixel2pixel_aspp_mobilenetv2_tv' #'fpnlite_pixel2pixel_aspp_resnet50'
 args.dataset_name = 'cityscapes_segmentation' #'cityscapes_segmentation' #'voc_segmentation'
 
 args.data_path = './data/datasets/cityscapes/data' #'./data/datasets/cityscapes/data' #'./data/datasets/voc'