simpler resize/upsample modules using scale_factor
authorManu Mathew <a0393608@ti.com>
Wed, 26 Feb 2020 06:38:34 +0000 (12:08 +0530)
committerManu Mathew <a0393608@ti.com>
Wed, 26 Feb 2020 06:39:15 +0000 (12:09 +0530)
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpnlite_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unetlite_pixel2pixel.py
modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py

index 2662799f5e76577cac426aae4a83ea9d11e8b0a7..fb5837c9fd8387b2d60d49f292896f3f87747efa 100644 (file)
@@ -51,6 +51,7 @@ def get_config():
     args.model_config.freeze_encoder = False            # do not update encoder weights
     args.model_config.freeze_decoder = False            # do not update decoder weights
     args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
+    args.model_config.target_input_ratio = 1            # Keep target size same as input size
 
     args.model = None                                   # the model itself can be given from outside
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
@@ -260,6 +261,8 @@ def main(args):
     print('{}'.format(Fore.RESET))
     # print everything for log
     print('=> args: {}'.format(args))
+    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
+
     print('=> will save everything to {}'.format(save_path))
 
     #################################################
index e6942fd481d1d843d5f4c11305013849b2b2642a..5931e831ce63f3dfb85e9b79e2499e99cc58cd96 100644 (file)
@@ -42,9 +42,9 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)
 
         upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
-        # use UpsampleGenericTo() instead of UpsampleTo() to break down large upsampling factors to multiples of 4 and 2 -
+        # use UpsampleScaleFactorGeneric() instead of UpsampleScaleFactor() to break down large upsampling factors to multiples of 4 and 2 -
         # useful if upsampling factors other than 4 and 2 are not supported.
-        self.upsample1 = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
+        self.upsample1 = xnn.layers.UpsampleScaleFactor(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
 
         self.cat = xnn.layers.CatBlock()
 
@@ -76,7 +76,7 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         x = self.aspp(x_features) if self.model_config.use_aspp else x_features
 
         # upsample low res features to match with shortcut
-        x = self.upsample1((x, x_s))
+        x = self.upsample1(x)
 
         # combine and do high res prediction
         x = self.cat((x,x_s))
@@ -85,12 +85,12 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
             x = self.pred(x)
 
             if self.model_config.final_upsample:
-                x = self.upsample2((x, x_input))
+                x = self.upsample2(x)
 
             if (not self.training) and (self.model_config.output_type == 'segmentation'):
                 x = torch.argmax(x, dim=1, keepdim=True)
 
-            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
+            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
 
         if self.model_config.freeze_decoder:
             x = x.detach()
@@ -118,6 +118,7 @@ def get_config_deeplav3lite_mnv2():
     model_config.split_outputs = False
     model_config.use_aspp = True
     model_config.fastdown = False
+    model_config.target_input_ratio = 1
 
     model_config.strides = (2,2,2,2,1)
     model_config.fastdown = False
index d3aeb263d04b93c9fec4644c51a40377c7fa8c86..1ac77418c5ad991f15062d38e891f6001c33cc50 100644 (file)
@@ -25,6 +25,7 @@ def get_config_fpnlitep2p_mnv2():
     model_config.groupwise_sep = False
     model_config.fastdown = False
     model_config.width_mult = 1.0
+    model_config.target_input_ratio = 1
 
     model_config.strides = (2,2,2,2,2)
     encoder_stride = np.prod(model_config.strides)
@@ -81,7 +82,7 @@ class FPNLitePyramid(torch.nn.Module):
             smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
                         if (inloop_fpn or all_outputs or is_last) else None
             self.smooth_convs.append(smooth_conv)
-            upsample = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
+            upsample = xnn.layers.UpsampleScaleFactor(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
             self.upsamples.append(upsample)
         #
     #
@@ -108,7 +109,7 @@ class FPNLitePyramid(torch.nn.Module):
             x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
             x_s = shortcut(x_s)
             # upsample current output and add to that
-            x = upsample((x,x_s))
+            x = upsample(x)
             x = x + x_s
             # smooth conv
             y = smooth_conv(x) if (smooth_conv is not None) else x
@@ -189,7 +190,7 @@ class FPNLitePixel2PixelDecoder(torch.nn.Module):
 
             # final prediction is the upsampled one
             if self.model_config.final_upsample:
-                x = self.upsample((x,x_input))
+                x = self.upsample(x)
 
             if (not self.training) and (self.output_type == 'segmentation'):
                 x = torch.argmax(x, dim=1, keepdim=True)
index 5fab7c07aa274a6dbd65fd9c7ede109c9d3dce72..972f708737930909305f7e7e7de221fe7dbce7fc 100644 (file)
@@ -4,9 +4,9 @@ from .... import xnn
 def add_lite_prediction_modules(self, model_config, current_channels, module_names):
     # prediction and upsample
     if self.model_config.final_prediction:
-        # use UpsampleGenericTo() instead of UpsampleTo(), to break down large upsampling factors to multiples of 4 and 2
-        # useful if upsampling factors other than 4 and 2 are not supported.
-        UpsampleClass = xnn.layers.UpsampleTo
+        # use UpsampleScaleFactorGeneric() instead of UpsampleScaleFactor(), to break down large upsampling factors to multiples of 4 and 2
+        # useful if scale_factors other than 4 and 2 are not supported.
+        UpsampleClass = xnn.layers.UpsampleScaleFactor
 
         # can control the range of final output with output_range
         final_activation = xnn.layers.get_fixed_pact2(output_range=model_config.output_range) if (model_config.output_range is not None) else False
@@ -32,9 +32,14 @@ def add_lite_prediction_modules(self, model_config, current_channels, module_nam
             setattr(self, module_names[0], pred)
 
             if self.model_config.final_upsample:
-                upsample2 = UpsampleClass(model_config.output_channels, model_config.output_channels, upstride2,
-                                    model_config.interpolation_type, model_config.interpolation_mode,
-                                    is_final_layer=True, final_activation=final_activation)
+                upstride2 = (upstride2//self.model_config.target_input_ratio)
+                if upstride2 > 1:
+                    upsample2 = UpsampleClass(model_config.output_channels, model_config.output_channels, upstride2,
+                                              model_config.interpolation_type, model_config.interpolation_mode,
+                                              is_final_layer=True, final_activation=final_activation)
+                else:
+                    upsample2 = xnn.layers.BypassBlock()
+                #
                 setattr(self, module_names[1], upsample2)
             #
         #
index b00e8b0beebe107d5177b96786bd9f2b72ec2156..ed51cef35470ab7f803cc704305a5e7249862033 100644 (file)
@@ -22,6 +22,7 @@ def get_config_unetlitep2p_mnv2():
     model_config.groupwise_sep = False
     model_config.fastdown = False
     model_config.width_mult = 1.0
+    model_config.target_input_ratio = 1
 
     model_config.strides = (2,2,2,2,2)
     encoder_stride = np.prod(model_config.strides)
@@ -67,7 +68,7 @@ class UNetLitePyramid(torch.nn.Module):
         upstride = 2
         activation2 = (activation, activation)
         for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
-            self.upsamples.append(xnn.layers.UpsampleTo(current_channels, current_channels, upstride, interpolation_type, interpolation_mode))
+            self.upsamples.append(xnn.layers.UpsampleScaleFactor(current_channels, current_channels, upstride, interpolation_type, interpolation_mode))
             self.concats.append(xnn.layers.CatBlock())
             smooth_channels = max(minimum_channels, feat_chan)
             self.smooth_convs.append( xnn.layers.ConvDWSepNormAct2d(current_channels+feat_chan, smooth_channels, kernel_size=kernel_size_smooth, activation=activation2))
@@ -92,7 +93,7 @@ class UNetLitePyramid(torch.nn.Module):
             shape_s[1] = short_chan
             x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
             # upsample current output and concat to that
-            x = upsample((x,x_s))
+            x = upsample(x)
             x = concat((x,x_s)) if (concat is not None) else x
             # smooth conv
             x = smooth_conv(x) if (smooth_conv is not None) else x
@@ -167,12 +168,12 @@ class UNetLitePixel2PixelDecoder(torch.nn.Module):
 
             # final prediction is the upsampled one
             if self.model_config.final_upsample:
-                x = self.upsample((x,x_input))
+                x = self.upsample(x)
 
             if (not self.training) and (self.output_type == 'segmentation'):
                 x = torch.argmax(x, dim=1, keepdim=True)
 
-            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
+            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
 
         return x
 
index 57fee250dc8edf096a84c4d0905fda6bcf97a8f1..ac0ee3ce556409a449047ffaf7cdbeb078abcf91 100644 (file)
@@ -1,6 +1,101 @@
+import torch
 from .deconv_blocks import *
 
-###############################################################
+
+##############################################################################################
+# Newer Resize/Upsample modules. Please use these modules instead of the older ResizeTo(), UpsampleTo()
+# The older modules may be removed in a later version.
+##############################################################################################
+
+
+# onnx export from PyTorch is creating a complicated graph - use this workaround for now until the onnx export is fixed.
+# the only way to create a simple graph with scale_factor seems to be to provide size as an integer to the interpolate function
+# this workaround seems to be working in onnx opset_version=9, however in opset_version=11, it still produces a complicated graph.
+class ResizeScaleFactor(torch.nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest'):
+        ''' Resize with scale_factor
+            This module exports an onnx graph with scale_factor
+        '''
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        assert scale_factor is not None, 'scale_factor must be specified'
+
+    def forward(self, x):
+        assert isinstance(x, torch.Tensor), 'must provide a single tensor as input'
+        scale_factor = (self.scale_factor, self.scale_factor) if not isinstance(self.scale_factor, (list,tuple)) else self.scale_factor
+        # generate size as a tuple and pass it - as onnx export inserts scale_factor if the size is a non-tensor
+        # this seems to be the only way to insert scale_factor in onnx export of Upsample/Resize
+        size = (int(x.shape[2]*scale_factor[0]), int(x.shape[3]*scale_factor[1]))
+        y = torch.nn.functional.interpolate(x, size=size, mode=self.mode)
+        return y
+
+
+def UpsampleScaleFactor(input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
+               is_final_layer=False, final_activation=True):
+    '''
+         is_final_layer: Final layer in a pixel2pixel network should not typically use BatchNorm for best accuracy
+         final_activation: use to control the range of final layer if required
+     '''
+    if interpolation_type == 'upsample':
+        upsample = ResizeScaleFactor(scale_factor=upstride, mode=interpolation_mode)
+    else:
+        assert upstride is not None, 'upstride must not be None in this interpolation_mode'
+        assert input_channels is not None, 'input_channels must not be None in this interpolation_mode'
+        assert output_channels is not None, 'output_channels must not be None in this interpolation_mode'
+        final_norm = (False if is_final_layer else True)
+        normalization = (True, final_norm)
+        activation = (False, final_activation)
+        if interpolation_type == 'deconv':
+            upsample = [DeConvDWSepNormAct2d(input_channels, output_channels, kernel_size=upstride * 2, stride=upstride,
+                                      normalization=normalization, activation=activation)]
+        elif interpolation_type == 'upsample_conv':
+            upsample = [ResizeScaleFactor(scale_factor=upstride, mode=interpolation_mode),
+                        ConvDWSepNormAct2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1),
+                                      normalization=normalization, activation=activation)]
+        elif interpolation_type == 'subpixel_conv':
+            upsample = [ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1),
+                                      normalization=normalization, activation=activation),
+                        torch.nn.PixelShuffle(upscale_factor=int(upstride))]
+        else:
+            assert False, f'invalid interpolation_type: {interpolation_type}'
+        #
+        upsample = torch.nn.Sequential(*upsample)
+    #
+    return upsample
+
+
+
+class UpsampleScaleFactorGeneric(torch.nn.Module):
+    def __init__(self, input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
+                 is_final_layer=False, final_activation=False):
+        '''
+            A Resize module that breaks down scale factors > 4 into multiples of 2 and 4
+            is_final_layer: Final layer in a pixel2pixel network should not typically use BatchNorm
+            final_activation: use to control the range of final layer if required
+        '''
+        super().__init__()
+        self.upsample_list = torch.nn.ModuleList()
+        while upstride >= 2:
+            upstride_layer = 4 if upstride > 4 else upstride
+            upsample = UpsampleScaleFactor(input_channels, output_channels, upstride_layer, interpolation_type, interpolation_mode,
+                                  is_final_layer=is_final_layer, final_activation=final_activation)
+            self.upsample_list.append(upsample)
+            upstride = upstride//4
+
+    def forward(self, x):
+        assert not isinstance(x, (list,tuple)), 'must provide a single tensor as input'
+        for idx, upsample in enumerate(self.upsample_list):
+            x = upsample(x)
+        #
+        return x
+
+
+
+##############################################################################################
+# The following modules will be deprecated in a later version. Please use the modules above.
+##############################################################################################
+
 class ResizeTo(torch.nn.Module):
     def __init__(self, mode='bilinear'):
         '''
@@ -18,11 +113,9 @@ class ResizeTo(torch.nn.Module):
         return y
 
 
-###############################################################
 def UpsampleTo(input_channels=None, output_channels=None, upstride=None, interpolation_type='upsample', interpolation_mode='bilinear',
                is_final_layer=False, final_activation=True):
     '''
-         A Resize module that breaks downscale factors > 4 to multiples of 2 and 4
          is_final_layer: Final layer in a pixel2pixel network should not typically use BatchNorm for best accuracy
          final_activation: use to control the range of final layer if required
      '''