Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py
ResizeWith, UpsampleWith classes that can export to onnx with scale_factor in opset_v...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet_utils.py
1 from .... import xnn
3 # add prediction and final upsample blocks to pixel2pixel models
def add_lite_prediction_modules(self, model_config, current_channels, module_names):
    """Attach the final prediction block and, optionally, the final upsample block.

    Blocks are attached to ``self`` with ``setattr`` under the names in
    ``module_names``. Nothing is added unless ``model_config.final_prediction``
    is truthy.

    Args:
        self: object (typically a decoder module) that receives the new blocks.
        model_config: configuration object. Attributes read here:
            ``final_prediction``, ``final_upsample``, ``output_range``,
            ``output_channels``, ``shortcut_strides``, ``interpolation_type``,
            ``interpolation_mode``, ``groupwise_sep``, ``linear_dw`` and
            ``target_input_ratio``. If ``None``, ``self.model_config`` is used.
        current_channels: number of input channels of the block that does the
            prediction.
        module_names: sequence of attribute names. ``module_names[0]`` always
            receives the prediction block (when final_prediction is set).
            ``module_names[1]`` receives the upsample block only when
            ``model_config.final_upsample`` is truthy.

    Returns:
        None. The blocks are attached to ``self`` as a side effect.
    """
    # Read every setting from one config object. This keeps the decision to
    # build a block and the parameters of that block consistent.
    if model_config is None:
        model_config = self.model_config
    #
    if not model_config.final_prediction:
        return
    #
    # use UpsampleWithGeneric() instead of UpsampleWith(), to break down large upsampling factors to multiples of 4 and 2
    # useful if scale_factor other than 4 and 2 are not supported.
    UpsampleClass = xnn.layers.UpsampleWith

    # can control the range of final output with output_range
    if model_config.output_range is not None:
        final_activation = xnn.layers.get_fixed_pact2(output_range=model_config.output_range)
    else:
        final_activation = False
    #
    upstride2 = model_config.shortcut_strides[0]
    upsample_has_conv = model_config.interpolation_type in ('deconv', 'upsample_conv', 'subpixel_conv')

    if model_config.final_upsample and upsample_has_conv:
        # some of the upsample blocks have conv or deconv layers
        # since these blocks can be used to do prediction as well, perform both prediction and upsample together in these cases.
        # otherwise, using these conv/deconv layers after the prediction (with very few channels) will make them difficult to train.
        pred = xnn.layers.BypassBlock()
        setattr(self, module_names[0], pred)

        # NOTE(review): unlike the branch below, upstride2 is not divided by
        # target_input_ratio here - confirm this is intended.
        upsample2 = UpsampleClass(current_channels, model_config.output_channels, upstride2,
                                  model_config.interpolation_type, model_config.interpolation_mode,
                                  is_final_layer=True, final_activation=final_activation)
        setattr(self, module_names[1], upsample2)
        return
    #
    # prediction followed by conventional interpolation
    if model_config.groupwise_sep:
        ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d
    else:
        ConvXWSepBlock = xnn.layers.ConvDWSepNormAct2d
    #
    # normalization only on the depthwise part (unless linear_dw); the final
    # activation (or none) is applied on the pointwise part.
    pred = ConvXWSepBlock(current_channels, model_config.output_channels, kernel_size=3,
                          normalization=((not model_config.linear_dw), False),
                          activation=(False, final_activation), groups=1)
    setattr(self, module_names[0], pred)

    if model_config.final_upsample:
        # floor division; assumes target_input_ratio is a positive int - TODO confirm
        upstride2 = (upstride2 // model_config.target_input_ratio)
        if upstride2 > 1:
            upsample2 = UpsampleClass(model_config.output_channels, model_config.output_channels, upstride2,
                                      model_config.interpolation_type, model_config.interpolation_mode,
                                      is_final_layer=True, final_activation=final_activation)
        else:
            # no resizing needed: pass the prediction through unchanged
            upsample2 = xnn.layers.BypassBlock()
        #
        setattr(self, module_names[1], upsample2)
    #