[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / fpn_pixel2pixel.py
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py
index 571e70feaa40c81325db5b46c2f3c5843ad4787b..b24a204e07700fec2b093520b2fc2b4a32245f9b 100644 (file)
model_config.aspp_chan = 256
model_config.aspp_dil = (6,12,18)
- model_config.inloop_fpn = False # inloop_fpn means the smooth convs are in the loop, after upsample
+ model_config.inloop_fpn = True #False # inloop_fpn means the smooth convs are in the loop, after upsample
model_config.kernel_size_smooth = 3
model_config.interpolation_type = 'upsample'
self.shortcut_strides = shortcut_strides
self.shortcut_channels = shortcut_channels
self.smooth_convs = torch.nn.ModuleList()
- self.shortcuts = torch.nn.ModuleList([self.create_shortcut(current_channels, decoder_channels, activation)])
+ self.shortcuts = torch.nn.ModuleList()
self.upsamples = torch.nn.ModuleList()
+ shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) if (current_channels != decoder_channels) else None
+ self.shortcuts.append(shortcut0)
+
+ smooth_conv0 = None #xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation, activation)) if all_outputs else None
+ self.smooth_convs.append(smooth_conv0)
+
upstride = 2
for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
shortcut = self.create_shortcut(feat_chan, decoder_channels, activation)
self.shortcuts.append(shortcut)
is_last = (idx == len(shortcut_channels)-1)
- if inloop_fpn or (all_outputs or is_last):
- smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation))
- else:
- smooth_conv = None
- #
+ smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
+ if (inloop_fpn or all_outputs or is_last) else None
self.smooth_convs.append(smooth_conv)
upsample = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
self.upsamples.append(upsample)
return shortcut
#
- def forward(self, x_list, in_shape):
+ def forward(self, x_input, x_list):
+ in_shape = x_input.shape
x = x_list[-1]
- x = self.shortcuts[0](x)
- outputs = [x]
- for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs, self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+
+ outputs = []
+ x = self.shortcuts[0](x) if (self.shortcuts[0] is not None) else x
+ y = self.smooth_convs[0](x) if (self.smooth_convs[0] is not None) else x
+ x = y if self.inloop_fpn else x
+ outputs.append(y)
+
+ for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+ # get the feature of lower stride
shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
shape_s[1] = short_chan
x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
x_s = shortcut(x_s)
+ # upsample current output and add the shortcut feature to it
x = upsample((x,x_s))
x = x + x_s
- if self.inloop_fpn:
- x = smooth_conv(x)
- outputs.append(x)
- elif (smooth_conv is not None):
- y = smooth_conv(x)
- outputs.append(y)
- #
+ # smooth conv
+ y = smooth_conv(x) if (smooth_conv is not None) else x
+ # use smooth output for next level in inloop_fpn
+ x = y if self.inloop_fpn else x
+ # output
+ outputs.append(y)
#
return outputs[::-1]
x_list[-1] = x
#
- x_list = self.fpn(x_list, in_shape)
+ x_list = self.fpn(x_input, x_list)
x = x_list[0]
if self.model_config.final_prediction:
#
if pretrained:
- model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+ model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
#
return model, change_names_dict
#
if pretrained:
- model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+ model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
return model, change_names_dict