[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet_utils.py
1 from .... import xnn
3 # add prediction and final upsample blocks to pixel2pixel models
def add_lite_prediction_modules(self, model_config, current_channels, module_names):
    """Attach the final prediction block and final upsample block to ``self``.

    The prediction block is stored as attribute ``module_names[0]`` and the
    upsample block as attribute ``module_names[1]`` (both via ``setattr``).
    Nothing is attached when ``self.model_config.final_prediction`` is falsy.

    Two layouts are possible:

    * Fused: ``self.model_config.final_upsample`` is set and
      ``self.model_config.interpolation_type`` is one of ``'deconv'``,
      ``'upsample_conv'`` or ``'subpixel_conv'``.
      - The prediction slot gets an ``xnn.layers.BypassBlock``.
      - The upsample block maps ``current_channels`` directly to
        ``model_config.output_channels``, using stride
        ``model_config.shortcut_strides[0]``.
    * Separate:
      - A separable conv block (``ConvGWSepNormAct2d`` if
        ``model_config.groupwise_sep``, else ``ConvDWSepNormAct2d``;
        kernel 3) maps ``current_channels`` to
        ``model_config.output_channels``.
      - If ``self.model_config.final_upsample`` is set, it is followed by an
        upsample block with stride
        ``shortcut_strides[0] // self.model_config.target_input_ratio``.
        A ``BypassBlock`` is used instead when that stride is not above 1.
      - If ``final_upsample`` is not set, ``module_names[1]`` is not assigned.

    Args:
        self: the model being built. Presumably a ``torch.nn.Module``, in which
            case the ``setattr`` calls register the sub-modules in this order.
        model_config: provides ``output_range``, ``shortcut_strides``,
            ``output_channels``, ``interpolation_type``, ``interpolation_mode``,
            ``groupwise_sep`` and ``linear_dw``. Some flags are read from
            ``self.model_config`` instead. Presumably the caller passes that
            same object here -- TODO confirm.
        current_channels: channel count of the incoming feature map.
        module_names: two attribute names, ``(prediction_name, upsample_name)``.

    Returns:
        None. Modules are attached to ``self`` as a side effect.
    """
    if not self.model_config.final_prediction:
        return

    # UpsampleWithGeneric() could be swapped in for UpsampleWith() to split large
    # upsampling factors into multiples of 4 and 2. That helps when scale
    # factors other than 4 and 2 are not supported.
    upsample_cls = xnn.layers.UpsampleWith

    # output_range = (low, high), when given, selects a fixed hardtanh-type final
    # activation (presumably clamping the output to that range).
    # When it is None, the final activation is disabled (False).
    out_range = model_config.output_range
    if out_range is None:
        last_act = False
    else:
        last_act = xnn.layers.get_fixed_hardtanh_type(out_range[0], out_range[1])
    up_stride = model_config.shortcut_strides[0]

    fused_pred_upsample = (self.model_config.final_upsample and
                           self.model_config.interpolation_type in ('deconv', 'upsample_conv', 'subpixel_conv'))

    if fused_pred_upsample:
        # These upsample blocks contain conv/deconv layers that can perform the
        # prediction too, so both steps are merged into one block. Putting such
        # layers after a prediction with very few channels would make them hard
        # to train.
        bypass_pred = xnn.layers.BypassBlock()
        setattr(self, module_names[0], bypass_pred)

        fused_upsample = upsample_cls(current_channels, model_config.output_channels, up_stride,
                                      model_config.interpolation_type, model_config.interpolation_mode,
                                      is_final_layer=True, final_activation=last_act)
        setattr(self, module_names[1], fused_upsample)
        return

    # Separate path: prediction conv first, then conventional interpolation.
    if model_config.groupwise_sep:
        sep_block_cls = xnn.layers.ConvGWSepNormAct2d
    else:
        sep_block_cls = xnn.layers.ConvDWSepNormAct2d
    # The (first, second) tuples presumably configure the depthwise/groupwise and
    # pointwise parts. Per the arguments below, the first part is normalized
    # unless linear_dw is set, and only the second part gets the final activation.
    pred_block = sep_block_cls(current_channels, model_config.output_channels, kernel_size=3,
                               normalization=((not model_config.linear_dw), False),
                               activation=(False, last_act), groups=1)
    setattr(self, module_names[0], pred_block)

    if not self.model_config.final_upsample:
        return

    # Reduce the upsampling factor by target_input_ratio. When no upsampling is
    # left (stride <= 1), use a pass-through block instead.
    up_stride = up_stride // self.model_config.target_input_ratio
    if up_stride > 1:
        upsample_block = upsample_cls(model_config.output_channels, model_config.output_channels, up_stride,
                                      model_config.interpolation_type, model_config.interpolation_mode,
                                      is_final_layer=True, final_activation=last_act)
    else:
        upsample_block = xnn.layers.BypassBlock()
    setattr(self, module_names[1], upsample_block)