986b46643beeaaf9038a9d85f7b14f1638ca846f
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / fpnlite_pixel2pixel.py
1 import torch
2 import numpy as np
3 from .... import xnn
5 from .pixel2pixelnet import *
6 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
# public API of this module: the model/decoder classes plus the named model builders
__all__ = ['FPNLitePixel2PixelASPP', 'FPNLitePixel2PixelDecoder',
           'fpnlite_pixel2pixel_aspp_mobilenetv2_tv', 'fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd', 'fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd',
           # no aspp models
           'fpnlite_pixel2pixel_mobilenetv2_tv', 'fpnlite_pixel2pixel_mobilenetv2_tv_fd',
           # resnet models
           'fpnlite_pixel2pixel_aspp_resnet50', 'fpnlite_pixel2pixel_aspp_resnet50_fd',
           ]
17 # config settings for mobilenetv2 backbone
def get_config_fpnlitep2p_mnv2():
    """Return the default FPNLite pixel2pixel config for a MobileNetV2-TV encoder.

    Produces an xnn.utils.ConfigNode populated with encoder strides, FPN shortcut
    definitions, ASPP settings and decoder options. The shortcut_channels values
    are specific to mobilenetv2 - change them for other networks.
    """
    model_config = xnn.utils.ConfigNode()

    # overall encoder stride is the product of the per-stage strides
    strides = (2, 2, 2, 2, 2)
    encoder_stride = np.prod(strides)

    defaults = {
        'num_classes': None,
        'num_decoders': None,
        'intermediate_outputs': True,
        'use_aspp': True,
        'use_extra_strides': False,
        'groupwise_sep': False,
        'fastdown': False,
        'width_mult': 1.0,
        'target_input_ratio': 1,
        'strides': strides,
        'shortcut_strides': (4, 8, 16, encoder_stride),
        'shortcut_channels': (24, 32, 96, 320),  # this is for mobilenetv2 - change for other networks
        'decoder_chan': 256,
        'aspp_chan': 256,
        'aspp_dil': (6, 12, 18),
        # inloop_fpn means the smooth convs are in the loop, after upsample
        'inloop_fpn': True,
        'kernel_size_smooth': 3,
        'interpolation_type': 'upsample',
        'interpolation_mode': 'bilinear',
        'final_prediction': True,
        'final_upsample': True,
        'output_range': None,
        'normalize_input': False,
        'split_outputs': False,
        'decoder_factor': 1.0,
        'activation': xnn.layers.DefaultAct2d,
        'linear_dw': False,
        'normalize_gradients': False,
        'freeze_encoder': False,
        'freeze_decoder': False,
        'multi_task': False,
    }
    for key, value in defaults.items():
        setattr(model_config, key, value)
    return model_config
60 ###########################################
###########################################
class FPNLitePyramid(torch.nn.Module):
    """Lightweight FPN top-down pyramid.

    Starts from the deepest (lowest resolution) feature, and for each shortcut
    level: upsamples the running feature by 2, adds a 1x1-projected lateral
    (shortcut) feature, and optionally applies a depthwise-separable smoothing
    conv. If inloop_fpn is True the smoothed output is fed to the next level;
    otherwise the unsmoothed sum is propagated and smoothing only produces the
    per-level outputs.
    """
    def __init__(self, current_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=False, all_outputs=False):
        # current_channels: channels of the deepest encoder feature
        # decoder_channels: common channel width used throughout the pyramid
        # shortcut_strides/shortcut_channels: stride and channel count of each
        #   lateral feature, ordered from deepest to shallowest as consumed here
        super().__init__()
        self.inloop_fpn = inloop_fpn
        self.shortcut_strides = shortcut_strides
        self.shortcut_channels = shortcut_channels
        self.smooth_convs = torch.nn.ModuleList()
        self.shortcuts = torch.nn.ModuleList()
        self.upsamples = torch.nn.ModuleList()

        # level 0 (deepest): project only if the channel counts differ
        shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) if (current_channels != decoder_channels) else None
        self.shortcuts.append(shortcut0)
        # no smoothing at the deepest level
        smooth_conv0 = None #xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation, activation)) if all_outputs else None
        self.smooth_convs.append(smooth_conv0)

        upstride = 2
        for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
            shortcut = self.create_shortcut(feat_chan, decoder_channels, activation)
            self.shortcuts.append(shortcut)
            # smoothing convs are only built where needed: every level for
            # inloop_fpn/all_outputs, otherwise just the final (shallowest) level
            is_last = (idx == len(shortcut_channels)-1)
            smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
                            if (inloop_fpn or all_outputs or is_last) else None
            self.smooth_convs.append(smooth_conv)
            upsample = xnn.layers.UpsampleWith(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
            self.upsamples.append(upsample)
        #
    #
    def create_shortcut(self, inch, outch, activation):
        """Create the 1x1 lateral projection used to match decoder channels."""
        shortcut = xnn.layers.ConvNormAct2d(inch, outch, kernel_size=1, activation=activation)
        return shortcut
    #
    def forward(self, x_input, x_list):
        # x_input: the network input tensor, used only for its shape (to locate
        #   lateral features of a given stride in x_list)
        # x_list: encoder features, deepest last
        # returns: per-level outputs, shallowest (highest resolution) first
        in_shape = x_input.shape
        x = x_list[-1]

        outputs = []
        x = self.shortcuts[0](x) if (self.shortcuts[0] is not None) else x
        y = self.smooth_convs[0](x) if (self.smooth_convs[0] is not None) else x
        # in inloop mode the smoothed feature is carried forward; otherwise the raw one
        x = y if self.inloop_fpn else x
        outputs.append(y)

        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
            # get the feature of lower stride by matching the expected shape
            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            shape_s[1] = short_chan
            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
            x_s = shortcut(x_s)
            # upsample current output and add the lateral feature to it
            x = upsample(x)
            x = x + x_s
            # smooth conv (may be None for levels that produce no output)
            y = smooth_conv(x) if (smooth_conv is not None) else x
            # use smooth output for next level in inloop_fpn
            x = y if self.inloop_fpn else x
            # output
            outputs.append(y)
        #
        return outputs[::-1]
class InLoopFPNLitePyramid(FPNLitePyramid):
    """FPNLite pyramid variant that smooths in the loop (inloop_fpn defaults to True)."""
    def __init__(self, input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=True, all_outputs=False):
        super().__init__(input_channels, decoder_channels, shortcut_strides, shortcut_channels,
                         activation, kernel_size_smooth, interpolation_type, interpolation_mode,
                         inloop_fpn=inloop_fpn, all_outputs=all_outputs)
129 ###########################################
class FPNLitePixel2PixelDecoder(torch.nn.Module):
    """FPNLite decoder for pixel2pixel tasks.

    Applies a receptive-field block (ASPP, extra-stride convs, or a plain 1x1
    conv) to the deepest encoder feature, runs an FPNLite top-down pyramid over
    the encoder features, then (optionally) predicts and upsamples to the
    input resolution.
    """
    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        activation = self.model_config.activation
        self.output_type = model_config.output_type
        # decoder width, scaled by decoder_factor (e.g. for slimmer models)
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)

        # receptive-field block: exactly one of the three variants below
        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid: two extra stride-2 convs instead of aspp
            # (the last two shortcut_channels entries correspond to these levels)
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
            current_channels = decoder_channels
        else:
            # no aspp: just project the deepest feature to decoder_channels
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels
        #

        # the deepest level is consumed by rfblock above, so the pyramid only
        # sees the remaining shortcut levels, deepest-first
        shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
        shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
        FPNType = InLoopFPNLitePyramid if model_config.inloop_fpn else FPNLitePyramid
        self.fpn = FPNType(current_channels, decoder_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                           self.model_config.interpolation_type, self.model_config.interpolation_mode)

        # add prediction & upsample modules (creates self.pred and self.upsample)
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
        #

    def forward(self, x_input, x, x_list):
        # x_input: list/tuple of network inputs (only the first is used, for shape)
        # x: deepest encoder feature; must be the last element of x_list
        # x_list: all intermediate encoder features
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the features must the last one in x_list'
        x_input = x_input[0]
        in_shape = x_input.shape

        # apply the receptive-field block; extra-stride convs append new levels,
        # the other variants replace the deepest feature in x_list
        if self.model_config.use_extra_strides:
            for blk in self.rfblock:
                x = blk(x)
                x_list += [x]
            #
        elif self.rfblock is not None:
            x = self.rfblock(x)
            x_list[-1] = x
        #

        # top-down pyramid; the highest-resolution output comes first
        x_list = self.fpn(x_input, x_list)
        x = x_list[0]

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample(x)
            #
            # at inference time, segmentation outputs are reduced to class ids
            if (not self.training) and (self.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)
            #
            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
        #
        return x
203 ###########################################
class FPNLitePixel2PixelASPP(Pixel2PixelNet):
    """Pixel2pixel network that pairs a given encoder with the FPNLite decoder."""
    def __init__(self, base_model, model_config):
        decoder_cls = FPNLitePixel2PixelDecoder
        super().__init__(base_model, decoder_cls, model_config)
209 ###########################################
def fpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Build an FPNLite pixel2pixel model with a MobileNetV2-TV encoder (with ASPP).

    model_config: user settings, merged on top of the mobilenetv2 defaults.
    pretrained: optional weights; loaded with the name remapping below.
    Returns (model, change_names_dict) - the dict maps torchvision-style
    parameter names to the names used in this model.
    """
    model_config = get_config_fpnlitep2p_mnv2().merge_from(model_config)
    # encoder setup
    base_model = MobileNetV2TVMI4(model_config.clone())
    # decoder setup
    model = FPNLitePixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)

    # replicate the single pretrained decoder across all task decoders
    decoder_map = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        # replicate encoder weights across all input streams
        stream_map = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict = {'^features.': list(stream_map),
                             '^classifier.': 'encoder.classifier.',
                             '^encoder.features.': list(stream_map),
                             '^decoders.0.': decoder_map}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^classifier.': 'encoder.classifier.',
                             '^decoders.0.': decoder_map}
    #
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
237 # fast down sampling model (encoder stride 64 model)
# fast down sampling model (encoder stride 64 model)
def fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling variant (encoder stride 64) of the mobilenetv2 aspp model."""
    cfg = get_config_fpnlitep2p_mnv2().merge_from(model_config)
    cfg.fastdown = True
    cfg.strides = (2,2,2,2,2)
    # strides double relative to the base model because of the extra downsampling
    cfg.shortcut_strides = (8,16,32,64)
    cfg.shortcut_channels = (24,32,96,320)
    cfg.decoder_chan = 256
    cfg.aspp_chan = 256
    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
249 # fast down sampling model (encoder stride 64 model) with fpn decoder channels 128
# fast down sampling model (encoder stride 64 model) with fpn decoder channels 128
def fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling mobilenetv2 aspp model with slimmer 128-channel decoder."""
    cfg = get_config_fpnlitep2p_mnv2().merge_from(model_config)
    cfg.fastdown = True
    cfg.strides = (2,2,2,2,2)
    # one extra shortcut level compared to the 256-channel fd model
    cfg.shortcut_strides = (4,8,16,32,64)
    cfg.shortcut_channels = (16,24,32,96,320)
    cfg.decoder_chan = 128
    cfg.aspp_chan = 128
    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
261 ##################
262 # similar to the original fpn model with extra convolutions with strides (no aspp)
# similar to the original fpn model with extra convolutions with strides (no aspp)
def fpnlite_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
    """MobileNetV2 FPNLite model without ASPP; uses extra stride-2 convs instead."""
    cfg = get_config_fpnlitep2p_mnv2().merge_from(model_config)
    cfg.use_aspp = False
    cfg.use_extra_strides = True
    # two extra levels (strides 64, 128) produced by the extra-stride convs
    cfg.shortcut_strides = (4, 8, 16, 32, 64, 128)
    cfg.shortcut_channels = (24, 32, 96, 320, 320, 256)
    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
272 # similar to the original fpn model with extra convolutions with strides (no aspp) - fast down sampling model (encoder stride 64 model)
# similar to the original fpn model with extra convolutions with strides (no aspp) - fast down sampling model (encoder stride 64 model)
def fpnlite_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling no-ASPP variant with extra stride-2 convolutions."""
    cfg = get_config_fpnlitep2p_mnv2().merge_from(model_config)
    cfg.use_aspp = False
    cfg.use_extra_strides = True
    cfg.fastdown = True
    cfg.strides = (2,2,2,2,2)
    # strides doubled by fastdown, plus the two extra-stride levels
    cfg.shortcut_strides = (8, 16, 32, 64, 128, 256)
    cfg.shortcut_channels = (24, 32, 96, 320, 320, 256)
    cfg.decoder_chan = 256
    cfg.aspp_chan = 256
    return fpnlite_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
286 ###########################################
def get_config_fpnlitep2p_resnet50():
    """ResNet50 FPNLite config - only the delta compared to the mobilenetv2 defaults."""
    cfg = get_config_fpnlitep2p_mnv2()
    # resnet50 stage output channels replace the mobilenetv2 ones
    cfg.shortcut_channels = (256,512,1024,2048)
    return cfg
def fpnlite_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
    """Build an FPNLite pixel2pixel model with a ResNet50 encoder (with ASPP).

    Returns (model, change_names_dict). The pretrained model provided by
    torchvision and what is defined here differ slightly in naming; the
    change_names_dict takes effect only if the direct weight load fails.
    """
    model_config = get_config_fpnlitep2p_resnet50().merge_from(model_config)
    # encoder setup
    base_model = ResNet50MI4(model_config.clone())
    # decoder setup
    model = FPNLitePixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)

    decoder_map = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        # replicate encoder weights across all input streams; suffix is appended
        # to each stream prefix ('' maps bare feature names, 'layer' has no dot)
        def _per_stream(suffix):
            return ['encoder.features.stream{}.{}'.format(stream, suffix) for stream in range(num_inputs)]
        change_names_dict = {'^conv1.': _per_stream('conv1.'),
                             '^bn1.': _per_stream('bn1.'),
                             '^relu.': _per_stream('relu.'),
                             '^maxpool.': _per_stream('maxpool.'),
                             '^layer': _per_stream('layer'),
                             '^features.': _per_stream(''),
                             '^encoder.features.': _per_stream(''),
                             '^decoders.0.': decoder_map}
    else:
        # finally take care of the change for fpn (features->encoder.features)
        change_names_dict = {'^conv1.': 'encoder.features.conv1.',
                             '^bn1.': 'encoder.features.bn1.',
                             '^relu.': 'encoder.features.relu.',
                             '^maxpool.': 'encoder.features.maxpool.',
                             '^layer': 'encoder.features.layer',
                             '^features.': 'encoder.features.',
                             '^decoders.0.': decoder_map}
    #
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
def fpnlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
    """Fast-downsampling variant (encoder stride 64) of the resnet50 aspp model."""
    cfg = get_config_fpnlitep2p_resnet50().merge_from(model_config)
    cfg.fastdown = True
    cfg.strides = (2,2,2,2,2)
    # alternatives kept for reference: (4,8,16,32,64) with (64,256,512,1024,2048)
    # channels and 128-wide decoder/aspp
    cfg.shortcut_strides = (8,16,32,64)
    cfg.shortcut_channels = (256,512,1024,2048)
    cfg.decoder_chan = 256
    cfg.aspp_chan = 256
    return fpnlite_pixel2pixel_aspp_resnet50(cfg, pretrained=pretrained)