]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/pixel2pixelnet_utils.py
added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / pixel2pixelnet_utils.py
index c1af102d056abc6b19e5eca4c811b40367a0c2ba..4a3b11ed895dddd2e48659353196e0e9b9eedefb 100644 (file)
@@ -1,12 +1,43 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 from .... import xnn
+import torch
 
 # add prediction and final upsample blocks to pixel2pixel models
-def add_lite_prediction_modules(self, model_config, current_channels, module_names):
+def add_lite_prediction_modules(self, model_config, current_channels, module_names, UpsampleClass = xnn.layers.UpsampleWith):
     # prediction and upsample
     if self.model_config.final_prediction:
         # use UpsampleWithGeneric() instead of UpsampleWith(), to break down large upsampling factors to multiples of 4 and 2
         # useful if scale_factor other than 4 and 2 are not supported.
-        UpsampleClass = xnn.layers.UpsampleWith
 
         # can control the range of final output with output_range
         output_range = model_config.output_range
@@ -14,36 +45,56 @@ def add_lite_prediction_modules(self, model_config, current_channels, module_nam
             if (output_range is not None) else False
         upstride2 = model_config.shortcut_strides[0]
 
-        if self.model_config.final_upsample and self.model_config.interpolation_type in ('deconv','upsample_conv','subpixel_conv'):
-            # some of the upsample blocks have conv or deconv layers
-            # since these blocks can be used to do prediction as well, perform both prediction and upsample together in these cases.
-            # otherwise, using these conv/decon layers after the prediction (with very few channels) will make them difficult to train.
-            pred = xnn.layers.BypassBlock()
-            setattr(self, module_names[0], pred)
-
-            upsample2 = UpsampleClass(current_channels, model_config.output_channels, upstride2,
-                                    model_config.interpolation_type, model_config.interpolation_mode,
-                                    is_final_layer=True, final_activation=final_activation)
-            setattr(self, module_names[1], upsample2)
+        if model_config.prediction_channels is not None and self.model_config.final_upsample and \
+                self.model_config.interpolation_type in ('deconv','upsample_conv','subpixel_conv'):
+            prediction_channels = max(model_config.output_channels, model_config.prediction_channels)
         else:
-            # prediction followed by conventional interpolation
-            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
-            group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
-            pred = ConvXWSepBlock(current_channels, model_config.output_channels, kernel_size=3,
-                                       normalization=((not model_config.linear_dw),False),
-                                       activation=(False,final_activation), groups=1, group_size_dw=group_size_dw)
-            setattr(self, module_names[0], pred)
-
-            if self.model_config.final_upsample:
-                upstride2 = (upstride2//self.model_config.target_input_ratio)
-                if upstride2 > 1:
-                    upsample2 = UpsampleClass(model_config.output_channels, model_config.output_channels, upstride2,
-                                              model_config.interpolation_type, model_config.interpolation_mode,
-                                              is_final_layer=True, final_activation=final_activation)
-                else:
-                    upsample2 = xnn.layers.BypassBlock()
-                #
-                setattr(self, module_names[1], upsample2)
+            prediction_channels = model_config.output_channels
+        #
+        # prediction followed by conventional interpolation
+        ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
+        group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
+        
+        # Do not use output range yet as this is not final layer. Next upsample has conv layer.
+        final_activation_conv_block = final_activation
+        if (model_config.interpolation_type == 'upsample_dwconv') and self.model_config.final_upsample:
+            final_activation_conv_block = False 
+
+        pred = ConvXWSepBlock(current_channels, prediction_channels, kernel_size=3,
+                                   normalization=((not model_config.linear_dw),False),
+                                   activation=(False,final_activation_conv_block), groups=1, group_size_dw=group_size_dw)
+        setattr(self, module_names[0], pred)
+
+        if self.model_config.final_upsample:
+            upstride2 = (upstride2//self.model_config.target_input_ratio)
+            if upstride2 > 1:
+                upsample2 = UpsampleClass(prediction_channels, model_config.output_channels, upstride2,
+                                          model_config.interpolation_type, model_config.interpolation_mode,
+                                          is_final_layer=True, final_activation=final_activation)
+            else:
+                upsample2 = xnn.layers.BypassBlock()
             #
+            setattr(self, module_names[1], upsample2)
         #
     #
+
+class UpsampleWithSlice(torch.nn.Module):
+    def __init__(self, input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
+               is_final_layer=False, final_activation=True):
+        super().__init__()
+        self.input_channels = input_channels
+        self.slice_size = 1
+        self.upsample2_slice1 = xnn.layers.UpsampleWith(self.slice_size, self.slice_size, upstride,
+                                         interpolation_type, interpolation_mode,
+                                         is_final_layer=is_final_layer, final_activation=final_activation)
+        self.upsample2_slice2 = xnn.layers.UpsampleWith(input_channels - self.slice_size, output_channels - self.slice_size, upstride,
+                                         interpolation_type, interpolation_mode,
+                                         is_final_layer=is_final_layer, final_activation=final_activation)
+        self.cat = xnn.layers.CatBlock()
+
+    def forward(self, x):
+        x1, x2 = x[:, :self.slice_size, ...], x[:, self.slice_size:self.input_channels, :]
+        x1 = self.upsample2_slice1(x1)
+        x2 = self.upsample2_slice2(x2)
+        x = torch.cat((x1, x2), dim=1)
+        return x
\ No newline at end of file