]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/bifpnlite_pixel2pixel.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / bifpnlite_pixel2pixel.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 #   list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 #   this list of conditions and the following disclaimer in the documentation
13 #   and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 #   contributors may be used to endorse or promote products derived from
17 #   this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
33 # Our implementation of BiFPN-Lite (i.e. without the weighting before adding tensors):
34 # Reference:
35 # EfficientDet: Scalable and Efficient Object Detection
36 # Mingxing Tan, Ruoming Pang, Quoc V. Le,
37 # Google Research, Brain Team
38 # https://arxiv.org/pdf/1911.09070.pdf
40 import copy
41 import torch
42 import numpy as np
43 from .... import xnn
45 from .pixel2pixelnet import *
46 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, \
47     RegNetX400MFMI4, RegNetX800MFMI4, RegNetX1p6GFMI4, RegNetX3p2GFMI4
# Names exported by `from <this module> import *`: the decoder and network classes
# followed by the model factory functions (each *_bgr name is an alias of the
# regnetx factory with the same prefix).
__all__ = ['BiFPNLitePixel2PixelASPP', 'BiFPNLitePixel2PixelDecoder',
           'bifpnlite_pixel2pixel_aspp_mobilenetv2_tv', 'bifpnlite_pixel2pixel_mobilenetv2_tv',
           'bifpnlite_pixel2pixel_aspp_regnetx400mf', 'bifpnlite_pixel2pixel_aspp_regnetx400mf_bgr',
           'bifpnlite_pixel2pixel_aspp_regnetx800mf', 'bifpnlite_pixel2pixel_aspp_regnetx800mf_bgr',
           'bifpnlite_pixel2pixel_aspp_regnetx1p6gf', 'bifpnlite_pixel2pixel_aspp_regnetx1p6gf_bgr',
           'bifpnlite_pixel2pixel_aspp_regnetx3p2gf', 'bifpnlite_pixel2pixel_aspp_regnetx3p2gf_bgr'
           ]
58 # config settings for mobilenetv2 backbone
def get_config_bifpnlitep2p_mnv2():
    """Build the default BiFPN-Lite pixel2pixel configuration (mobilenetv2 backbone).

    Returns:
        An ``xnn.utils.ConfigNode`` carrying the default values listed below.
        The regnetx configurations in this file start from this node and
        override a few of its entries.
    """
    encoder_strides = (2,2,2,2,2)
    # overall stride of the encoder = product of the per-stage strides
    encoder_stride = np.prod(encoder_strides)

    # entries are applied in this order (same order as they were originally assigned)
    defaults = {
        'num_classes': None,
        'num_decoders': None,
        'intermediate_outputs': True,
        'use_aspp': True,
        'use_extra_strides': False,
        'groupwise_sep': False,
        'fastdown': False,
        'width_mult': 1.0,
        'target_input_ratio': 1,

        'num_bifpn_blocks': 4,
        'num_head_blocks': 0,           # 1
        'num_fpn_outs': 6,
        'strides': encoder_strides,
        'shortcut_strides': (4,8,16,encoder_stride),
        # this is for mobilenetv2 - change for other networks
        'shortcut_channels': (24,32,96,320),

        'aspp_dil': (6,12,18),
        'decoder_chan': 128,            # 256
        'aspp_chan': 128,               # 256
        'fpn_chan': 128,                # 256
        'head_chan': 128,               # 256

        'kernel_size_smooth': 3,
        'interpolation_type': 'upsample',
        'interpolation_mode': 'bilinear',

        'final_prediction': True,
        'final_upsample': True,
        'output_range': None,

        'normalize_input': False,
        'split_outputs': False,
        'decoder_factor': 1.0,
        'activation': xnn.layers.DefaultAct2d,
        'linear_dw': False,
        'normalize_gradients': False,
        'freeze_encoder': False,
        'freeze_decoder': False,
        'multi_task': False,
    }

    model_config = xnn.utils.ConfigNode()
    for name, value in defaults.items():
        setattr(model_config, name, value)
    #
    return model_config
106 ###########################################
class BiFPNLite(torch.nn.Module):
    """A stack of BiFPNLiteBlock modules applied to the encoder shortcut features.

    Args:
        model_config: configuration node; ``shortcut_strides`` is read in forward().
        in_channels: list/tuple with the channel count of each shortcut feature.
        intermediate_channels: channel count used between the blocks.
        out_channels: channel count produced by the last block.
        num_outs: number of pyramid levels produced (also used as the number of
            levels inside every block).
        num_blocks: number of BiFPNLiteBlock modules to stack (must be an int).
        add_extra_convs: 'on_input' or anything else (treated as on-output);
            only consulted when num_outs > num_fpn_outs.
        group_size_dw, normalization, activation, **kwargs: forwarded to the blocks.
    """
    def __init__(self, model_config, in_channels=None, intermediate_channels=None, out_channels=None,
                 num_outs=5, num_blocks=None, add_extra_convs = 'on_output',
                 group_size_dw=None, normalization=None, activation=None, **kwargs):
        super().__init__()
        self.model_config = model_config
        self.add_extra_convs = add_extra_convs
        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels
        self.out_channels = out_channels
        self.num_outs = num_outs
        self.num_fpn_outs = num_outs # for now set them to be same - but can be different

        blocks = []
        for i in range(num_blocks):
            # the first block consumes the backbone shortcuts; every later block consumes
            # the num_fpn_outs maps (intermediate_channels wide) produced by the previous block
            block_in_channels = [intermediate_channels for _ in range(self.num_fpn_outs)] if i>0 else in_channels
            if i < (num_blocks-1):
                # the initial bifpn blocks can operate with fewer number of channels
                block_id = i
                # fix: describe the block by the channels it actually receives
                # (block_in_channels was previously computed but the backbone list was passed)
                bi_fpn = BiFPNLiteBlock(block_id=block_id, in_channels=block_in_channels, out_channels=intermediate_channels,
                                num_outs=self.num_fpn_outs, group_size_dw=group_size_dw, normalization=normalization,
                                activation=activation, **kwargs)
            else:
                # block_id=0 will cause shortcut connections to be created to change the number of channels
                block_id = 0 if ((num_blocks == 1) or (out_channels != intermediate_channels)) else i
                # for segmentation, last one doesn't need down blocks as they are not used.
                bi_fpn = BiFPNLiteBlock(block_id=block_id, up_only=True, in_channels=block_in_channels, out_channels=out_channels,
                                        num_outs=self.num_fpn_outs, group_size_dw=group_size_dw, normalization=normalization,
                                        activation=activation, **kwargs)
            #
            blocks.append(bi_fpn)
        #
        self.bifpn_blocks = torch.nn.Sequential(*blocks)

        # extra downsampling levels beyond what the blocks produce
        # (not created while num_fpn_outs is set equal to num_outs above)
        NormType1 = normalization[-1] if isinstance(normalization,(list,tuple)) else normalization
        self.extra_convs = torch.nn.ModuleList()
        if self.num_outs > self.num_fpn_outs:
            in_ch = self.in_channels[-1] if self.add_extra_convs == 'on_input' else self.out_channels
            DownsampleType = torch.nn.MaxPool2d
            for i in range(self.num_outs-self.num_fpn_outs):
                extra_conv = BiFPNLiteBlock.build_downsample_module(in_ch, self.out_channels, kernel_size=3, stride=2,
                                    group_size_dw=group_size_dw, normalization=NormType1, activation=None,
                                    DownsampleType=DownsampleType)
                self.extra_convs.append(extra_conv)
                in_ch = self.out_channels
            #
        #

    def forward(self, x_input, x_list):
        """Run the block stack on the shortcut features taken from x_list.

        Args:
            x_input: the network input tensor; only its shape is used.
            x_list: list of encoder feature maps.

        Returns:
            list of output feature maps (first entry corresponds to the first shortcut).
        """
        in_shape = x_input.shape
        inputs = []
        # look up one feature per (channel, stride) pair by its expected shape
        # NOTE(review): presumably get_blob_from_list returns the entry of x_list
        # whose shape matches shape_s - confirm against xnn.utils
        for s_chan, s_stride in zip(self.in_channels, self.model_config.shortcut_strides):
            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            shape_s[1] = s_chan
            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
            inputs.append(x_s)
        #

        # zip() truncates silently - catch a shortcut_strides that is shorter than in_channels
        assert len(inputs) == len(self.in_channels)
        outputs = self.bifpn_blocks(inputs)
        outputs = list(outputs)
        if self.num_outs > self.num_fpn_outs:
            inp = inputs[-1] if self.add_extra_convs == 'on_input' else outputs[-1]
            for i in range(self.num_outs-self.num_fpn_outs):
                extra_inp = self.extra_convs[i](inp)
                outputs.append(extra_inp)
                inp = extra_inp
            #
        #
        return outputs
class BiFPNLiteBlock(torch.nn.Module):
    """One BiFPN-Lite block: a top-down pass followed by an optional bottom-up pass.

    Feature maps are combined with plain AddBlock modules (no learned weighting).
    When block_id is 0, input convolutions are created to bring every input to
    out_channels and to derive any extra (stride 2) levels from the last input.
    With up_only=True the bottom-up modules are not created and forward()
    returns the top-down results.
    """
    def __init__(self, block_id=None, in_channels=None, out_channels=None, num_outs=None, up_only=False, start_level=0, end_level=-1,
                 add_extra_convs=None, group_size_dw=None, normalization=xnn.layers.DefaultNorm2d, activation=xnn.layers.DefaultAct2d):
        super().__init__()
        assert isinstance(in_channels, (list,tuple))
        self.up_only = up_only
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        uses_all_inputs = (end_level == -1)
        self.backbone_end_level = self.num_ins if uses_all_inputs else end_level
        if uses_all_inputs:
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        #
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        self.block_id = block_id
        assert block_id is not None, f'block_id must be valid: {block_id}'

        # normalization / activation may be given as a pair - the last entry is
        # the one used for the stand-alone layers created below
        last_norm = normalization[-1] if isinstance(normalization,(list,tuple)) else normalization
        last_act = activation[-1] if isinstance(activation,(list,tuple)) else activation
        pool_type = torch.nn.MaxPool2d

        def make_fusion_conv():
            # 3x3 depthwise separable conv applied after each addition
            return xnn.layers.ConvDWSepNormAct2d(out_channels, out_channels, 3, padding=1,
                        group_size_dw=group_size_dw, normalization=normalization, activation=activation)
        #

        # input convs (and extra levels, e.g., RetinaNet) - only for the block with id 0
        if block_id == 0:
            self.num_backbone_convs = (self.backbone_end_level - self.start_level)
            self.extra_levels = num_outs - self.num_backbone_convs
            self.in_convs = torch.nn.ModuleList()
            for level in range(num_outs):
                if level < self.num_backbone_convs:
                    # regular level: taken from the corresponding input
                    src_channels, conv_stride = in_channels[self.start_level + level], 1
                elif level == self.num_backbone_convs:
                    # first extra level: derived from the last input
                    src_channels, conv_stride = in_channels[-1], 2
                else:
                    # further extra levels: derived from the previous extra level
                    src_channels, conv_stride = out_channels, 2
                #
                self.in_convs.append(
                    BiFPNLiteBlock.build_downsample_module(src_channels, out_channels, kernel_size=3, stride=conv_stride,
                                   group_size_dw=group_size_dw, normalization=last_norm, activation=None,
                                   DownsampleType=pool_type))
            #
        #

        # containers are registered in this fixed order
        self.ups = torch.nn.ModuleList()
        self.up_convs = torch.nn.ModuleList()
        self.up_acts = torch.nn.ModuleList()
        self.up_adds = torch.nn.ModuleList()
        if not up_only:
            self.downs = torch.nn.ModuleList()
            self.down_convs = torch.nn.ModuleList()
            self.down_acts = torch.nn.ModuleList()
            self.down_adds1 = torch.nn.ModuleList()
            self.down_adds2 = torch.nn.ModuleList()
        #
        for _ in range(self.num_outs-1):
            # top-down modules
            self.ups.append(xnn.layers.ResizeWith(scale_factor=2, mode='bilinear'))
            self.up_convs.append(make_fusion_conv())
            self.up_acts.append(last_act())
            self.up_adds.append(xnn.layers.AddBlock())
            if up_only:
                continue
            #
            # bottom-up modules
            self.downs.append(pool_type(kernel_size=3, stride=2, padding=1))
            self.down_convs.append(make_fusion_conv())
            self.down_acts.append(last_act())
            self.down_adds1.append(xnn.layers.AddBlock())
            self.down_adds2.append(xnn.layers.AddBlock())
        #

    def forward(self, inputs):
        """Fuse the feature maps in `inputs`; returns a tuple with num_outs entries."""
        # input convs
        if self.block_id == 0:
            feats = []
            for k in range(self.num_backbone_convs):
                feats.append(self.in_convs[k](inputs[self.start_level + k]))
            #
            # extra levels are chained, starting from the last input
            carry = inputs[-1]
            for k in range(self.num_backbone_convs, self.num_outs):
                carry = self.in_convs[k](carry)
                feats.append(carry)
            #
        else:
            feats = inputs
        #

        # top-down pass: from the last level towards the first
        top_down = [None] * self.num_outs
        top_down[-1] = feats[-1]
        for lvl in reversed(range(self.num_outs-1)):
            upsampled = self.ups[lvl](top_down[lvl+1])
            fused = self.up_adds[lvl]((feats[lvl], upsampled))
            top_down[lvl] = self.up_convs[lvl](self.up_acts[lvl](fused))
        #
        if self.up_only:
            return tuple(top_down)
        #

        # bottom-up pass: from the first level towards the last
        outs = [None] * self.num_outs
        outs[0] = top_down[0]
        for lvl in range(self.num_outs-1):
            nxt = lvl + 1
            if feats[nxt] is top_down[nxt]:
                # the last level has no separate top-down result - do not add it to itself
                skip = feats[nxt]
            else:
                skip = self.down_adds1[lvl]((feats[nxt], top_down[nxt]))
            #
            pooled = self.downs[lvl](outs[lvl])
            merged = self.down_adds2[lvl]((skip, pooled))
            outs[nxt] = self.down_convs[lvl](self.down_acts[lvl](merged))
        #
        return tuple(outs)

    @staticmethod
    def build_downsample_module(in_channels, out_channels, kernel_size, stride,
                                group_size_dw=None, normalization=None, activation=None,
                                DownsampleType=None):
        """Create the module that adapts a feature map in channels and/or resolution.

        - stride == 1: a 1x1 ConvNormAct2d (also when the channel counts already match)
        - stride > 1 with matching channels: DownsampleType alone
        - otherwise: DownsampleType followed by a 1x1 ConvNormAct2d

        group_size_dw is accepted but not used here.
        """
        padding = kernel_size//2

        def pointwise_conv():
            return xnn.layers.ConvNormAct2d(in_channels, out_channels, kernel_size=1, stride=1,
                                padding=0, normalization=normalization, activation=activation)
        #

        if stride == 1:
            return pointwise_conv()
        #
        if in_channels == out_channels and stride > 1:
            return DownsampleType(kernel_size=kernel_size, stride=stride, padding=padding)
        #
        return torch.nn.Sequential(
                    DownsampleType(kernel_size=kernel_size, stride=stride, padding=padding),
                    pointwise_conv())
329 ###########################################
class BiFPNLitePixel2PixelDecoder(torch.nn.Module):
    """Decoder made of a receptive-field block, a BiFPNLite neck, an optional head
    and (when model_config.final_prediction is set) prediction/upsample modules.

    The receptive-field block (self.rfblock) is chosen from model_config:
      - use_aspp: xnn.layers.DWASPPLiteBlock on the last shortcut
      - else use_extra_strides: two stride-2 ConvDWSepNormAct2d layers
      - else: a 1x1 ConvNormAct2d
    use_aspp takes precedence when both flags are set; forward() follows the
    same precedence.

    Args:
        model_config: configuration node; must provide output_type in addition to
            the entries set by the get_config_* functions in this file.
    """
    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        normalization = xnn.layers.DefaultNorm2d
        normalization_dws = (normalization, normalization)
        activation = xnn.layers.DefaultAct2d
        activation_dws = (activation, activation)
        self.output_type = model_config.output_type
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)
        # group_size_dw is optional in the config (only the regnetx configs set it)
        group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None

        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels,
                                dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation,
                                group_size_dw=group_size_dw)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(
                xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3,
                        stride=2, activation=(activation, activation), group_size_dw=group_size_dw),
                xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3,
                        stride=2, activation=(activation, activation), group_size_dw=group_size_dw))
            current_channels = decoder_channels
        else:
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels
        #

        # the last fpn input is the output of rfblock, so its channel count is replaced
        fpn_in_channels = list(copy.deepcopy(self.model_config.shortcut_channels))
        fpn_in_channels[-1] = current_channels
        fpn_channels = round(self.model_config.fpn_chan*self.model_config.decoder_factor)
        out_channels = round(self.model_config.head_chan*self.model_config.decoder_factor)
        num_fpn_outs = self.model_config.num_fpn_outs
        num_bifpn_blocks = self.model_config.num_bifpn_blocks
        self.fpn = BiFPNLite(self.model_config, in_channels=fpn_in_channels, intermediate_channels=fpn_channels,
                        out_channels=out_channels, num_outs=num_fpn_outs, num_blocks=num_bifpn_blocks,
                        group_size_dw=group_size_dw, normalization=normalization_dws, activation=activation_dws)

        head = []
        current_channels = out_channels
        if self.model_config.num_head_blocks > 0:
            for h_idx in range(self.model_config.num_head_blocks):
                hblock = xnn.layers.ConvDWSepNormAct2d(current_channels, out_channels, kernel_size=3, stride=1,
                                                       group_size_dw=group_size_dw, normalization=normalization_dws,
                                                       activation=activation_dws)
                current_channels = out_channels
                head.append(hblock)
            #
            self.head = torch.nn.Sequential(*head)
        else:
            self.head = None
        #

        # add prediction & upsample modules
        # NOTE(review): presumably this attaches self.pred and self.upsample,
        # which forward() uses - confirm in pixel2pixelnet
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
        #

    def forward(self, x_input, x, x_list):
        """Decode the encoder features.

        Args:
            x_input: list/tuple (at most 2 entries); the first entry is the input tensor.
            x: the last encoder feature - must be the same object as x_list[-1].
            x_list: list of encoder features; it is modified in place.

        Returns:
            The feature map of the first pyramid level after the optional head,
            prediction and upsample steps; in eval mode with
            output_type == 'segmentation' the argmax over dim 1 is returned.
        """
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the features must the last one in x_list'
        x_input = x_input[0]
        in_shape = x_input.shape

        # fix: follow the same precedence as __init__ - the two-stage Sequential exists
        # only when use_aspp is off, so it must not be iterated when use_aspp is also set
        extra_strides_active = self.model_config.use_extra_strides and (not self.model_config.use_aspp)
        if extra_strides_active:
            # each stride-2 stage contributes one more feature map
            for blk in self.rfblock:
                x = blk(x)
                x_list += [x]
            #
        elif self.rfblock is not None:
            # the processed feature replaces the last encoder feature
            x = self.rfblock(x)
            x_list[-1] = x
        #

        x_list = self.fpn(x_input, x_list)
        x = x_list[0]

        x = self.head(x) if (self.head is not None) else x

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample(x)

            if (not self.training) and (self.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
        #
        return x
434 ###########################################
class BiFPNLitePixel2PixelASPP(Pixel2PixelNet):
    """Pixel2PixelNet that uses BiFPNLitePixel2PixelDecoder as its decoder class."""
    def __init__(self, base_model, model_config):
        decoder_class = BiFPNLitePixel2PixelDecoder
        super().__init__(base_model, decoder_class, model_config)
440 ###########################################
def bifpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Create the BiFPN-Lite pixel2pixel model with a MobileNetV2 (torchvision style) encoder.

    Args:
        model_config: user configuration, merged on top of the mobilenetv2 defaults.
        pretrained: weights to load; nothing is loaded when this is falsy.

    Returns:
        (model, change_names_dict) - the model and the name-mapping dictionary
        that is passed to xnn.utils.load_weights.
    """
    model_config = get_config_bifpnlitep2p_mnv2().merge_from(model_config)

    # the encoder gets its own copy of the configuration
    base_model = MobileNetV2TVMI4(model_config.clone())
    model = BiFPNLitePixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders
    #

    def stream_prefixes():
        # one encoder prefix per input stream (a new list on every call)
        return ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
    #

    change_names_dict = {}
    if num_inputs > 1:
        change_names_dict['^features.'] = stream_prefixes()
        change_names_dict['^classifier.'] = 'encoder.classifier.'
        change_names_dict['^encoder.features.'] = stream_prefixes()
    else:
        change_names_dict['^features.'] = 'encoder.features.'
        change_names_dict['^classifier.'] = 'encoder.classifier.'
    #
    # the first decoder's entry is mapped to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
469 ##################
470 # similar to the original fpn model with extra convolutions with strides (no aspp)
def bifpnlite_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
    """Same as bifpnlite_pixel2pixel_aspp_mobilenetv2_tv, but with use_aspp forced to False."""
    merged_config = get_config_bifpnlitep2p_mnv2().merge_from(model_config)
    merged_config.use_aspp = False
    return bifpnlite_pixel2pixel_aspp_mobilenetv2_tv(merged_config, pretrained=pretrained)
477 ###########################################
# there is nothing specific about bgr in this model -
# the _bgr names are just a reminder that regnet models are typically trained with bgr input
def bifpnlite_pixel2pixel_aspp_regnetx(model_config, pretrained=None, base_model_class=None):
    """Create the BiFPN-Lite pixel2pixel model on top of the given regnetx encoder class.

    Args:
        model_config: complete configuration (already merged with the defaults).
        pretrained: weights to load right away; when falsy, a custom
            load_weights function is attached to the model instead.
        base_model_class: encoder class, instantiated with a clone of model_config.

    Returns:
        (model, change_names_dict)
    """
    # encoder, then the full network around it
    base_model = base_model_class(model_config.clone())
    model = BiFPNLitePixel2PixelASPP(base_model, model_config)

    # the names in the pretrained checkpoint differ slightly from the names used here.
    # note: this change_names_dict takes effect only if the direct load fails.
    # it also covers the features -> encoder.features change.
    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders
    #

    def stream_prefixes():
        # one encoder prefix per input stream (a new list on every call)
        return ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
    #

    change_names_dict = {}
    if num_inputs > 1:
        change_names_dict['^features.'] = stream_prefixes()
        change_names_dict['^encoder.features.'] = stream_prefixes()
    else:
        change_names_dict['^stem.'] = 'encoder.features.stem.'
        change_names_dict['^s1'] = 'encoder.features.s1'
        change_names_dict['^s2'] = 'encoder.features.s2'
        change_names_dict['^s3'] = 'encoder.features.s3'
        change_names_dict['^s4'] = 'encoder.features.s4'
        change_names_dict['^features.'] = 'encoder.features.'
    #
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if not pretrained:
        # need to use state_dict_name as the checkpoint uses a different name for state_dict
        # so provide a custom load_weights for the model
        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
                                       state_dict_name=['state_dict','model_state']):
            xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict,
                                   ignore_size=ignore_size, verbose=verbose,
                                   state_dict_name=state_dict_name)
        #
        model.load_weights = load_weights_func
    else:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
                                       state_dict_name=['state_dict','model_state'])
    #
    return model, change_names_dict
520 ###########################################
# config settings for regnetx400mf backbone
def get_config_bifpnlite_regnetx400mf():
    """Return the mobilenetv2 default config with the regnetx400mf specific entries replaced."""
    model_config = get_config_bifpnlitep2p_mnv2()
    # only the delta compared to the one defined for mobilenetv2
    overrides = {
        'group_size_dw': 16,
        'shortcut_channels': (32,64,160,384),
    }
    for name, value in overrides.items():
        setattr(model_config, name, value)
    #
    return model_config
def bifpnlite_pixel2pixel_aspp_regnetx400mf(model_config, pretrained=None):
    """Create the BiFPN-Lite pixel2pixel model with the RegNetX400MFMI4 encoder."""
    merged_config = get_config_bifpnlite_regnetx400mf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained, base_model_class=RegNetX400MFMI4)


# alias - same function under a name that mentions bgr
bifpnlite_pixel2pixel_aspp_regnetx400mf_bgr = bifpnlite_pixel2pixel_aspp_regnetx400mf
537 ###########################################
# config settings for regnetx800mf backbone
def get_config_bifpnlite_regnetx800mf():
    """Return the mobilenetv2 default config with the regnetx800mf specific entries replaced."""
    model_config = get_config_bifpnlitep2p_mnv2()
    # only the delta compared to the one defined for mobilenetv2
    overrides = {
        'group_size_dw': 16,
        'shortcut_channels': (64,128,288,672),
    }
    for name, value in overrides.items():
        setattr(model_config, name, value)
    #
    return model_config
def bifpnlite_pixel2pixel_aspp_regnetx800mf(model_config, pretrained=None):
    """Create the BiFPN-Lite pixel2pixel model with the RegNetX800MFMI4 encoder."""
    merged_config = get_config_bifpnlite_regnetx800mf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained, base_model_class=RegNetX800MFMI4)


# alias - same function under a name that mentions bgr
bifpnlite_pixel2pixel_aspp_regnetx800mf_bgr = bifpnlite_pixel2pixel_aspp_regnetx800mf
554 ###########################################
# config settings for regnetx1p6gf backbone
def get_config_bifpnlite_regnetx1p6gf():
    """Return the mobilenetv2 default config with the regnetx1p6gf specific entries replaced."""
    model_config = get_config_bifpnlitep2p_mnv2()
    # only the delta compared to the one defined for mobilenetv2.
    # group size is 24. make the decoder channels multiples of 24
    overrides = {
        'group_size_dw': 24,
        'decoder_chan': 168,    # 264
        'aspp_chan': 168,       # 264
        'fpn_chan': 168,        # 264
        'head_chan': 168,       # 264
        'shortcut_channels': (72, 168, 408, 912),
    }
    for name, value in overrides.items():
        setattr(model_config, name, value)
    #
    return model_config
def bifpnlite_pixel2pixel_aspp_regnetx1p6gf(model_config, pretrained=None):
    """Create the BiFPN-Lite pixel2pixel model with the RegNetX1p6GFMI4 encoder."""
    merged_config = get_config_bifpnlite_regnetx1p6gf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained, base_model_class=RegNetX1p6GFMI4)


# alias - same function under a name that mentions bgr
bifpnlite_pixel2pixel_aspp_regnetx1p6gf_bgr = bifpnlite_pixel2pixel_aspp_regnetx1p6gf
577 ###########################################
# config settings for regnetx3p2gf backbone
def get_config_bifpnlite_regnetx3p2gf():
    """Return the mobilenetv2 default config with the regnetx3p2gf specific entries replaced."""
    model_config = get_config_bifpnlitep2p_mnv2()
    # only the delta compared to the one defined for mobilenetv2.
    # group size is 48. make the decoder channels multiples of 48
    overrides = {
        'group_size_dw': 48,
        'decoder_chan': 192,    # 288
        'aspp_chan': 192,       # 288
        'fpn_chan': 192,        # 288
        'head_chan': 192,       # 288
        'shortcut_channels': (96, 192, 432, 1008),
    }
    for name, value in overrides.items():
        setattr(model_config, name, value)
    #
    return model_config
def bifpnlite_pixel2pixel_aspp_regnetx3p2gf(model_config, pretrained=None):
    """Create the BiFPN-Lite pixel2pixel model with the RegNetX3p2GFMI4 encoder."""
    merged_config = get_config_bifpnlite_regnetx3p2gf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained, base_model_class=RegNetX3p2GFMI4)


# alias - same function under a name that mentions bgr
bifpnlite_pixel2pixel_aspp_regnetx3p2gf_bgr = bifpnlite_pixel2pixel_aspp_regnetx3p2gf