[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / bifpnlite_pixel2pixel.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 # list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 # this list of conditions and the following disclaimer in the documentation
13 # and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
33 # Our implementation of BiFPN-Lite (i.e. without the weighting before adding tensors):
34 # Reference:
35 # EfficientDet: Scalable and Efficient Object Detection
36 # Mingxing Tan, Ruoming Pang, Quoc V. Le,
37 # Google Research, Brain Team
38 # https://arxiv.org/pdf/1911.09070.pdf
40 import copy
41 import torch
42 import numpy as np
43 from .... import xnn
45 from .pixel2pixelnet import *
46 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, \
47 RegNetX400MFMI4, RegNetX800MFMI4, RegNetX1p6GFMI4, RegNetX3p2GFMI4
# public api of this module (names exported by `from ... import *`)
__all__ = [
    # model classes
    'BiFPNLitePixel2PixelASPP',
    'BiFPNLitePixel2PixelDecoder',
    # mobilenetv2 encoder
    'bifpnlite_pixel2pixel_aspp_mobilenetv2_tv',
    'bifpnlite_pixel2pixel_mobilenetv2_tv',
    # regnetx encoders, each followed by its _bgr alias
    'bifpnlite_pixel2pixel_aspp_regnetx400mf',
    'bifpnlite_pixel2pixel_aspp_regnetx400mf_bgr',
    'bifpnlite_pixel2pixel_aspp_regnetx800mf',
    'bifpnlite_pixel2pixel_aspp_regnetx800mf_bgr',
    'bifpnlite_pixel2pixel_aspp_regnetx1p6gf',
    'bifpnlite_pixel2pixel_aspp_regnetx1p6gf_bgr',
    'bifpnlite_pixel2pixel_aspp_regnetx3p2gf',
    'bifpnlite_pixel2pixel_aspp_regnetx3p2gf_bgr',
]
58 # config settings for mobilenetv2 backbone
def get_config_bifpnlitep2p_mnv2():
    """Return the default config node of the BiFPN-Lite pixel2pixel model.

    The channel settings (shortcut_channels) match a mobilenetv2 encoder; the regnetx
    config functions in this file start from this config and override only what differs.
    """
    strides = (2, 2, 2, 2, 2)
    # overall downsampling factor of the encoder: the product of the stage strides
    encoder_stride = np.prod(strides)

    # every setting, in the order it is written into the config node
    settings = dict(
        num_classes=None,
        num_decoders=None,
        intermediate_outputs=True,
        use_aspp=True,
        use_extra_strides=False,
        groupwise_sep=False,
        fastdown=False,
        width_mult=1.0,
        target_input_ratio=1,
        # bifpn / head structure
        num_bifpn_blocks=4,
        num_head_blocks=0,  # 1
        num_fpn_outs=6,
        strides=strides,
        shortcut_strides=(4, 8, 16, encoder_stride),
        # this is for mobilenetv2 - change for other networks
        shortcut_channels=(24, 32, 96, 320),
        aspp_dil=(6, 12, 18),
        # channel widths (256 is the commented-out wider alternative)
        decoder_chan=128,  # 256
        aspp_chan=128,  # 256
        fpn_chan=128,  # 256
        head_chan=128,  # 256
        kernel_size_smooth=3,
        interpolation_type='upsample',
        interpolation_mode='bilinear',
        # prediction / output
        final_prediction=True,
        final_upsample=True,
        output_range=None,
        normalize_input=False,
        split_outputs=False,
        decoder_factor=1.0,
        activation=xnn.layers.DefaultAct2d,
        linear_dw=False,
        normalize_gradients=False,
        freeze_encoder=False,
        freeze_decoder=False,
        multi_task=False,
    )

    model_config = xnn.utils.ConfigNode()
    for name, value in settings.items():
        setattr(model_config, name, value)
    #
    return model_config
106 ###########################################
class BiFPNLite(torch.nn.Module):
    """BiFPN-Lite neck: a stack of BiFPNLiteBlock modules, plus optional extra levels.

    Args:
        model_config: config node; forward() reads model_config.shortcut_strides.
        in_channels: channel count of each encoder feature used as input, paired
            (via zip) with model_config.shortcut_strides.
        intermediate_channels: channel count used by every block except the last one.
        out_channels: channel count of the outputs of the last block.
        num_outs: number of output feature maps.
        num_blocks: number of BiFPNLiteBlock modules to stack.
        add_extra_convs: 'on_input' builds the extra levels from the last input feature;
            any other value builds them from the last bifpn output.
        group_size_dw, normalization, activation: forwarded to the blocks. normalization
            may be a list/tuple, whose last entry is used for the extra convs.
        **kwargs: forwarded to every BiFPNLiteBlock.
    """
    def __init__(self, model_config, in_channels=None, intermediate_channels=None, out_channels=None,
                 num_outs=5, num_blocks=None, add_extra_convs='on_output',
                 group_size_dw=None, normalization=None, activation=None, **kwargs):
        super().__init__()
        self.model_config = model_config
        self.add_extra_convs = add_extra_convs
        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels
        self.out_channels = out_channels
        self.num_outs = num_outs
        # for now set them to be same - but can be different.
        # note: while they are equal, the extra convs below are never built.
        self.num_fpn_outs = num_outs

        blocks = []
        for i in range(num_blocks):
            # block 0 consumes the encoder features; every later block consumes the
            # num_fpn_outs maps (intermediate_channels each) produced by the block before it
            block_in_channels = [intermediate_channels for _ in range(self.num_fpn_outs)] if i > 0 else in_channels
            if i < (num_blocks-1):
                # the initial bifpn blocks can operate with fewer number of channels.
                # only block 0 gets block_id=0, which builds the input convs that adapt the
                # encoder features to intermediate_channels and num_fpn_outs levels.
                bi_fpn = BiFPNLiteBlock(block_id=i, in_channels=block_in_channels, out_channels=intermediate_channels,
                                        num_outs=self.num_fpn_outs, group_size_dw=group_size_dw,
                                        normalization=normalization, activation=activation, **kwargs)
            else:
                # block_id=0 makes the block build input convs that change the number of channels
                # (needed when this is the only block or when out_channels differs from the
                # channels produced by the previous block)
                block_id = 0 if ((num_blocks == 1) or (out_channels != intermediate_channels)) else i
                # for segmentation, last one doesn't need down blocks as they are not used.
                bi_fpn = BiFPNLiteBlock(block_id=block_id, up_only=True, in_channels=block_in_channels,
                                        out_channels=out_channels, num_outs=self.num_fpn_outs,
                                        group_size_dw=group_size_dw, normalization=normalization,
                                        activation=activation, **kwargs)
            #
            blocks.append(bi_fpn)
        #
        self.bifpn_blocks = torch.nn.Sequential(*blocks)

        # optional extra (downsampled) levels beyond the bifpn outputs
        NormType1 = normalization[-1] if isinstance(normalization, (list, tuple)) else normalization
        self.extra_convs = torch.nn.ModuleList()
        if self.num_outs > self.num_fpn_outs:
            in_ch = self.in_channels[-1] if self.add_extra_convs == 'on_input' else self.out_channels
            DownsampleType = torch.nn.MaxPool2d
            for i in range(self.num_outs-self.num_fpn_outs):
                extra_conv = BiFPNLiteBlock.build_downsample_module(in_ch, self.out_channels, kernel_size=3, stride=2,
                                                                    group_size_dw=group_size_dw, normalization=NormType1,
                                                                    activation=None, DownsampleType=DownsampleType)
                self.extra_convs.append(extra_conv)
                in_ch = self.out_channels
            #
        #

    def forward(self, x_input, x_list):
        """Run the bifpn stack on the encoder features.

        Args:
            x_input: the network input; only its shape is used, to derive the expected
                shape of each shortcut feature.
            x_list: list of encoder features; for every (in_channels, shortcut_strides)
                pair the feature is looked up with xnn.utils.get_blob_from_list using
                the expected shape.

        Returns:
            list of output feature maps (bifpn outputs followed by any extra levels).
        """
        in_shape = x_input.shape
        inputs = []
        for s_chan, s_stride in zip(self.in_channels, self.model_config.shortcut_strides):
            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            shape_s[1] = s_chan
            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
            inputs.append(x_s)
        #
        # zip() silently truncates - make sure every input level was found
        assert len(inputs) == len(self.in_channels)
        outputs = self.bifpn_blocks(inputs)
        outputs = list(outputs)
        if self.num_outs > self.num_fpn_outs:
            inp = inputs[-1] if self.add_extra_convs == 'on_input' else outputs[-1]
            for i in range(self.num_outs-self.num_fpn_outs):
                extra_inp = self.extra_convs[i](inp)
                outputs.append(extra_inp)
                inp = extra_inp
            #
        #
        return outputs
class BiFPNLiteBlock(torch.nn.Module):
    """One BiFPN-Lite block: a top-down (up) pass, followed by a bottom-up (down) pass unless up_only.

    Feature maps are fused by plain addition (xnn.layers.AddBlock) - unlike the original
    BiFPN, no learned weights are applied before adding.

    Args:
        block_id: must not be None. block_id == 0 builds input convs (in_convs) that project
            the inputs to out_channels and create the extra levels; for any other value the
            inputs are used as they are (presumably already num_outs levels with out_channels
            channels each - confirm at the call site).
        in_channels: list/tuple with the channel count of each input level.
        out_channels: channel count of every fused / output feature map.
        num_outs: number of output levels.
        up_only: if True, the down pass is neither built nor run.
        start_level, end_level: range of input levels used by the input convs
            (end_level=-1 means up to the last input).
        add_extra_convs: stored, but not used by this class.
        group_size_dw: forwarded to the depthwise-separable convs.
        normalization, activation: a single type or a list/tuple (e.g. a (dw, pw) pair);
            the last entry of normalization is used for the input convs, and the last entry
            of activation is the activation applied right after each addition.
    """
    def __init__(self, block_id=None, in_channels=None, out_channels=None, num_outs=None, up_only=False, start_level=0, end_level=-1,
                 add_extra_convs=None, group_size_dw=None, normalization=xnn.layers.DefaultNorm2d, activation=xnn.layers.DefaultAct2d):
        super(BiFPNLiteBlock, self).__init__()
        assert isinstance(in_channels, (list,tuple))
        self.up_only = up_only
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        #
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        self.block_id = block_id
        assert block_id is not None, f'block_id must be valid: {block_id}'

        NormType = normalization
        ActType = activation
        DownsampleType = torch.nn.MaxPool2d
        UpsampleType = xnn.layers.ResizeWith
        ConvModuleWrapper = xnn.layers.ConvDWSepNormAct2d
        # last entry of a (dw, pw) pair, or the type itself
        NormType1 = normalization[-1] if isinstance(normalization,(list,tuple)) else normalization
        ActType1 = activation[-1] if isinstance(activation,(list,tuple)) else activation
        upsample_cfg = dict(scale_factor=2, mode='bilinear')

        # input convs (block 0 only): one stride-1 conv per backbone level to adapt the channels,
        # then one stride-2 downsample per extra level; the first extra level starts from the
        # last input, later ones from the previous extra level (which has out_channels)
        if block_id == 0:
            self.num_backbone_convs = (self.backbone_end_level - self.start_level)
            self.extra_levels = num_outs - self.num_backbone_convs
            self.in_convs = torch.nn.ModuleList()
            for i in range(num_outs):
                if i < self.num_backbone_convs:
                    in_ch = in_channels[self.start_level + i]
                elif i == self.num_backbone_convs:
                    in_ch = in_channels[-1]
                else:
                    in_ch = out_channels
                #
                stride = 1 if i < self.num_backbone_convs else 2
                in_conv = BiFPNLiteBlock.build_downsample_module(in_ch, out_channels, kernel_size=3, stride=stride,
                                    group_size_dw=group_size_dw, normalization=NormType1, activation=None,
                                    DownsampleType=DownsampleType)
                self.in_convs.append(in_conv)
            #
        #

        # per-level modules of the top-down pass (and of the bottom-up pass unless up_only);
        # entry i of each list serves the fusion between level i and level i+1
        self.ups = torch.nn.ModuleList()
        self.up_convs = torch.nn.ModuleList()
        self.up_acts = torch.nn.ModuleList()
        self.up_adds = torch.nn.ModuleList()
        if not up_only:
            self.downs = torch.nn.ModuleList()
            self.down_convs = torch.nn.ModuleList()
            self.down_acts = torch.nn.ModuleList()
            self.down_adds1 = torch.nn.ModuleList()
            self.down_adds2 = torch.nn.ModuleList()
        #
        for i in range(self.num_outs-1):
            # up modules
            up = UpsampleType(**upsample_cfg)
            up_conv = ConvModuleWrapper(out_channels, out_channels, 3, padding=1,
                                        group_size_dw=group_size_dw, normalization=NormType, activation=ActType)
            up_act = ActType1()
            self.ups.append(up)
            self.up_convs.append(up_conv)
            self.up_acts.append(up_act)
            self.up_adds.append(xnn.layers.AddBlock())
            # down modules
            if not up_only:
                down = DownsampleType(kernel_size=3, stride=2, padding=1)
                down_conv = ConvModuleWrapper(out_channels, out_channels, 3, padding=1,
                                              group_size_dw=group_size_dw, normalization=NormType, activation=ActType)
                down_act = ActType1()
                self.downs.append(down)
                self.down_convs.append(down_conv)
                self.down_acts.append(down_act)
                self.down_adds1.append(xnn.layers.AddBlock())
                self.down_adds2.append(xnn.layers.AddBlock())
            #

    def forward(self, inputs):
        """Fuse the feature levels.

        Args:
            inputs: sequence of feature maps; level i+1 is upsampled by 2 to be added to
                level i, so consecutive levels are expected to differ by a factor of 2 in
                spatial size.

        Returns:
            tuple of num_outs feature maps: the top-down results if up_only, otherwise the
            bottom-up results.
        """
        # in convs
        if self.block_id == 0:
            ins = [self.in_convs[i](inputs[self.start_level+i]) for i in range(self.num_backbone_convs)]
            # extra levels are chained: each one downsamples the previous one
            extra_in = inputs[-1]
            for i in range(self.num_backbone_convs, self.num_outs):
                extra_in = self.in_convs[i](extra_in)
                ins.append(extra_in)
            #
        else:
            ins = inputs
        #
        # up convs: from the last level to the first, add the upsampled coarser result
        ups = [None] * self.num_outs
        ups[-1] = ins[-1]
        for i in range(self.num_outs-2, -1, -1):
            add_block = self.up_adds[i]
            ups[i] = self.up_convs[i](self.up_acts[i](
                add_block((ins[i], self.ups[i](ups[i+1])))
            ))
        #
        if self.up_only:
            return tuple(ups)
        else:
            # down convs: from the first level to the last, add the downsampled finer result
            outs = [None] * self.num_outs
            outs[0] = ups[0]
            for i in range(0, self.num_outs-1):
                add_block1 = self.down_adds1[i]
                # for the last level ups[-1] is ins[-1] itself, so it is not added twice
                res = add_block1((ins[i+1], ups[i+1])) if (ins[i+1] is not ups[i+1]) else ins[i+1]
                add_block2 = self.down_adds2[i]
                outs[i+1] = self.down_convs[i](self.down_acts[i](
                    add_block2((res,self.downs[i](outs[i])))
                ))
            #
            return tuple(outs)

    @staticmethod
    def build_downsample_module(in_channels, out_channels, kernel_size, stride,
                                group_size_dw=None, normalization=None, activation=None,
                                DownsampleType=None):
        """Build a module that maps in_channels to out_channels with the given stride.

        - stride 1: a 1x1 ConvNormAct2d (also when in_channels == out_channels).
        - same channels, stride > 1: only DownsampleType(kernel_size, stride).
        - different channels, stride > 1: DownsampleType followed by a 1x1 ConvNormAct2d.

        kernel_size is used only by the downsample layer; group_size_dw is accepted but not used.
        """
        NormType = normalization
        ActType = activation
        ConvModuleWrapper = xnn.layers.ConvNormAct2d
        padding = kernel_size//2
        if in_channels == out_channels and stride == 1:
            block = ConvModuleWrapper(in_channels, out_channels, kernel_size=1, stride=1,
                                      padding=0, normalization=NormType, activation=ActType)
        elif in_channels == out_channels and stride > 1:
            block = DownsampleType(kernel_size=kernel_size, stride=stride, padding=padding)
        elif in_channels != out_channels and stride == 1:
            block = ConvModuleWrapper(in_channels, out_channels, kernel_size=1, stride=stride,
                                      padding=0, normalization=NormType, activation=ActType)
        else:
            block = torch.nn.Sequential(
                DownsampleType(kernel_size=kernel_size, stride=stride, padding=padding),
                ConvModuleWrapper(in_channels, out_channels, kernel_size=1, stride=1,
                                  padding=0, normalization=NormType, activation=ActType))
        #
        return block
329 ###########################################
class BiFPNLitePixel2PixelDecoder(torch.nn.Module):
    """Pixel2pixel decoder built around BiFPN-Lite.

    Pipeline: receptive-field block (rfblock) on the last encoder feature -> BiFPNLite over
    the shortcut features -> optional head of depthwise-separable convs -> optional
    prediction and upsampling modules (added by add_lite_prediction_modules).

    Args:
        model_config: config node (see get_config_bifpnlitep2p_mnv2 for the fields used).
    """
    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        normalization = xnn.layers.DefaultNorm2d
        normalization_dws = (normalization, normalization)
        activation = xnn.layers.DefaultAct2d
        activation_dws = (activation, activation)
        self.output_type = model_config.output_type
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)
        # group_size_dw is optional in the config (the regnetx configs set it); it is looked up
        # once here and shared by the rfblock, the bifpn and the head
        group_size_dw = getattr(model_config, 'group_size_dw', None)

        # receptive-field block on the last encoder feature; its output has decoder_channels
        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels,
                                                      dilation=self.model_config.aspp_dil, avg_pool=False,
                                                      activation=activation, group_size_dw=group_size_dw)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid: two stride-2 depthwise-separable convs whose outputs
            # forward() appends to x_list.
            # NOTE(review): assumes shortcut_channels[-3] is the channel count of the last encoder
            # feature in configs that enable use_extra_strides - confirm against those configs.
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(
                xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3,
                    stride=2, activation=(activation, activation), group_size_dw=group_size_dw),
                xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3,
                    stride=2, activation=(activation, activation), group_size_dw=group_size_dw))
            current_channels = decoder_channels
        else:
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels
        #

        # the last bifpn input is the rfblock output, so its channel count is replaced
        fpn_in_channels = list(copy.deepcopy(self.model_config.shortcut_channels))
        fpn_in_channels[-1] = current_channels
        fpn_channels = round(self.model_config.fpn_chan*self.model_config.decoder_factor)
        out_channels = round(self.model_config.head_chan*self.model_config.decoder_factor)
        num_fpn_outs = self.model_config.num_fpn_outs
        num_bifpn_blocks = self.model_config.num_bifpn_blocks
        self.fpn = BiFPNLite(self.model_config, in_channels=fpn_in_channels, intermediate_channels=fpn_channels,
                             out_channels=out_channels, num_outs=num_fpn_outs, num_blocks=num_bifpn_blocks,
                             group_size_dw=group_size_dw, normalization=normalization_dws, activation=activation_dws)

        # optional head of depthwise-separable convs applied on the first bifpn output
        head = []
        current_channels = out_channels
        if self.model_config.num_head_blocks > 0:
            for _ in range(self.model_config.num_head_blocks):
                hblock = xnn.layers.ConvDWSepNormAct2d(current_channels, out_channels, kernel_size=3, stride=1,
                                                       group_size_dw=group_size_dw, normalization=normalization_dws,
                                                       activation=activation_dws)
                current_channels = out_channels
                head.append(hblock)
            #
            self.head = torch.nn.Sequential(*head)
        else:
            self.head = None
        #

        # add prediction & upsample modules
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
        #

    def forward(self, x_input, x, x_list):
        """Decode the encoder features.

        Args:
            x_input: list/tuple with at most 2 entries; only x_input[0] is used (its shape
                locates the shortcut features and is checked against the output size).
            x: the last encoder feature; must be the same object as x_list[-1].
            x_list: list of encoder features. NOTE: modified in place - the rfblock output
                replaces the last entry (or, with use_extra_strides, is appended).

        Returns:
            the decoder output; in eval mode with output_type == 'segmentation' (and
            final_prediction set) it is the per-pixel argmax over dim 1.
        """
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the features must the last one in x_list'
        x_input = x_input[0]
        in_shape = x_input.shape

        if self.model_config.use_extra_strides:
            for blk in self.rfblock:
                x = blk(x)
                x_list += [x]
            #
        elif self.rfblock is not None:
            x = self.rfblock(x)
            x_list[-1] = x
        #

        x_list = self.fpn(x_input, x_list)
        # only the first bifpn output (level 0, stride shortcut_strides[0]) is decoded further
        x = x_list[0]

        x = self.head(x) if (self.head is not None) else x

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample(x)

            if (not self.training) and (self.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
        #
        return x
434 ###########################################
class BiFPNLitePixel2PixelASPP(Pixel2PixelNet):
    """Pixel2PixelNet configured with BiFPNLitePixel2PixelDecoder as its decoder class."""
    def __init__(self, base_model, model_config):
        decoder_class = BiFPNLitePixel2PixelDecoder
        super().__init__(base_model, decoder_class, model_config)
440 ###########################################
def bifpnlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Build the BiFPN-Lite pixel2pixel model with a MobileNetV2TVMI4 encoder.

    Args:
        model_config: overrides, merged on top of get_config_bifpnlitep2p_mnv2().
        pretrained: optional weights; when truthy they are loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict), where change_names_dict maps name patterns of a
        checkpoint to the parameter names used in this model.
    """
    model_config = get_config_bifpnlitep2p_mnv2().merge_from(model_config)
    # the encoder gets its own copy of the config
    encoder = MobileNetV2TVMI4(model_config.clone())
    model = BiFPNLitePixel2PixelASPP(encoder, model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders
    #
    decoder_names = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if num_inputs > 1:
        # every input stream gets its own copy of the encoder features
        change_names_dict = {
            '^features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
            '^classifier.': 'encoder.classifier.',
            '^encoder.features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
            '^decoders.0.': decoder_names,
        }
    else:
        change_names_dict = {
            '^features.': 'encoder.features.',
            '^classifier.': 'encoder.classifier.',
            '^decoders.0.': decoder_names,
        }
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
469 ##################
# similar to the original fpn model, but without aspp: the decoder then uses extra strided convs
# if model_config.use_extra_strides is set (default False), otherwise a 1x1 conv
def bifpnlite_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
    """Same as bifpnlite_pixel2pixel_aspp_mobilenetv2_tv, but with use_aspp forced to False."""
    config = get_config_bifpnlitep2p_mnv2().merge_from(model_config)
    config.use_aspp = False
    return bifpnlite_pixel2pixel_aspp_mobilenetv2_tv(config, pretrained=pretrained)
477 ###########################################
# there is nothing specific about bgr in this model - the _bgr names below are
# just a reminder that regnet models are typically trained with bgr input
def bifpnlite_pixel2pixel_aspp_regnetx(model_config, pretrained=None, base_model_class=None):
    """Build the BiFPN-Lite pixel2pixel model with a RegNetX encoder.

    Args:
        model_config: fully merged config (the per-variant wrappers merge the user config
            onto the matching get_config_bifpnlite_regnetx*() defaults).
        pretrained: optional checkpoint; when truthy it is loaded right away.
        base_model_class: encoder class, instantiated with a clone of model_config.

    Returns:
        (model, change_names_dict). When pretrained is not given, model.load_weights is
        set to a function that loads weights later with the same state_dict names.
    """
    # encoder setup
    model_config_e = model_config.clone()
    base_model = base_model_class(model_config_e)
    # decoder setup
    model = BiFPNLitePixel2PixelASPP(base_model, model_config)

    # the names in the pretrained checkpoint and in the model defined here differ slightly
    # note: that this change_names_dict will take effect only if the direct load fails
    # finally take care of the encoder being a sub-module here (features->encoder.features)
    num_inputs = len(model_config.input_channels)
    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
    if num_inputs > 1:
        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    else:
        change_names_dict = {'^stem.': 'encoder.features.stem.',
                             '^s1': 'encoder.features.s1', '^s2': 'encoder.features.s2',
                             '^s3': 'encoder.features.s3', '^s4': 'encoder.features.s4',
                             '^features.': 'encoder.features.',
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    #

    # need to use state_dict_name as the checkpoint uses a different name for state_dict
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
                                       state_dict_name=['state_dict','model_state'])
    else:
        # provide a custom load_weights for the model, so that later loads use the same names.
        # a private sentinel marks "not given": a fresh default list is then built on every call
        # (a list default would be shared across calls), and an explicit None is still passed through.
        use_default = object()

        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
                              state_dict_name=use_default):
            if state_dict_name is use_default:
                state_dict_name = ['state_dict','model_state']
            #
            # return the result, as the pretrained branch above does with xnn.utils.load_weights
            return xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict,
                                          ignore_size=ignore_size, verbose=verbose,
                                          state_dict_name=state_dict_name)
        #
        model.load_weights = load_weights_func
    #
    return model, change_names_dict
520 ###########################################
# config settings for regnetx400mf backbone
def get_config_bifpnlite_regnetx400mf():
    """Return the BiFPN-Lite config for a regnetx400mf encoder.

    Only the delta compared to the mobilenetv2 config is set here.
    """
    config = get_config_bifpnlitep2p_mnv2()
    # decoder channels (128) are already multiples of this group size
    config.group_size_dw = 16
    config.shortcut_channels = (32, 64, 160, 384)
    return config
def bifpnlite_pixel2pixel_aspp_regnetx400mf(model_config, pretrained=None):
    """Build the BiFPN-Lite pixel2pixel model with a RegNetX400MFMI4 encoder.

    model_config is merged on top of get_config_bifpnlite_regnetx400mf().
    """
    merged_config = get_config_bifpnlite_regnetx400mf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained=pretrained,
                                              base_model_class=RegNetX400MFMI4)


# same function under a _bgr name (reminder that regnet models are typically trained with bgr input)
bifpnlite_pixel2pixel_aspp_regnetx400mf_bgr = bifpnlite_pixel2pixel_aspp_regnetx400mf
537 ###########################################
# config settings for regnetx800mf backbone
def get_config_bifpnlite_regnetx800mf():
    """Return the BiFPN-Lite config for a regnetx800mf encoder.

    Only the delta compared to the mobilenetv2 config is set here.
    """
    config = get_config_bifpnlitep2p_mnv2()
    # decoder channels (128) are already multiples of this group size
    config.group_size_dw = 16
    config.shortcut_channels = (64, 128, 288, 672)
    return config
def bifpnlite_pixel2pixel_aspp_regnetx800mf(model_config, pretrained=None):
    """Build the BiFPN-Lite pixel2pixel model with a RegNetX800MFMI4 encoder.

    model_config is merged on top of get_config_bifpnlite_regnetx800mf().
    """
    merged_config = get_config_bifpnlite_regnetx800mf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained=pretrained,
                                              base_model_class=RegNetX800MFMI4)


# same function under a _bgr name (reminder that regnet models are typically trained with bgr input)
bifpnlite_pixel2pixel_aspp_regnetx800mf_bgr = bifpnlite_pixel2pixel_aspp_regnetx800mf
554 ###########################################
# config settings for regnetx1p6gf backbone
def get_config_bifpnlite_regnetx1p6gf():
    """Return the BiFPN-Lite config for a regnetx1p6gf encoder.

    Only the delta compared to the mobilenetv2 config is set here.
    """
    config = get_config_bifpnlitep2p_mnv2()
    # group size is 24. make the decoder channels multiples of 24
    config.group_size_dw = 24
    # 168 = 7 x 24 (264 = 11 x 24 is a wider alternative)
    config.decoder_chan = config.aspp_chan = config.fpn_chan = config.head_chan = 168
    config.shortcut_channels = (72, 168, 408, 912)
    return config
def bifpnlite_pixel2pixel_aspp_regnetx1p6gf(model_config, pretrained=None):
    """Build the BiFPN-Lite pixel2pixel model with a RegNetX1p6GFMI4 encoder.

    model_config is merged on top of get_config_bifpnlite_regnetx1p6gf().
    """
    merged_config = get_config_bifpnlite_regnetx1p6gf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained=pretrained,
                                              base_model_class=RegNetX1p6GFMI4)


# same function under a _bgr name (reminder that regnet models are typically trained with bgr input)
bifpnlite_pixel2pixel_aspp_regnetx1p6gf_bgr = bifpnlite_pixel2pixel_aspp_regnetx1p6gf
577 ###########################################
# config settings for regnetx3p2gf backbone
def get_config_bifpnlite_regnetx3p2gf():
    """Return the BiFPN-Lite config for a regnetx3p2gf encoder.

    Only the delta compared to the mobilenetv2 config is set here.
    """
    config = get_config_bifpnlitep2p_mnv2()
    # group size is 48. make the decoder channels multiples of 48
    config.group_size_dw = 48
    # 192 = 4 x 48 (288 = 6 x 48 is a wider alternative)
    config.decoder_chan = config.aspp_chan = config.fpn_chan = config.head_chan = 192
    config.shortcut_channels = (96, 192, 432, 1008)
    return config
def bifpnlite_pixel2pixel_aspp_regnetx3p2gf(model_config, pretrained=None):
    """Build the BiFPN-Lite pixel2pixel model with a RegNetX3p2GFMI4 encoder.

    model_config is merged on top of get_config_bifpnlite_regnetx3p2gf().
    """
    merged_config = get_config_bifpnlite_regnetx3p2gf().merge_from(model_config)
    return bifpnlite_pixel2pixel_aspp_regnetx(merged_config, pretrained=pretrained,
                                              base_model_class=RegNetX3p2GFMI4)


# same function under a _bgr name (reminder that regnet models are typically trained with bgr input)
bifpnlite_pixel2pixel_aspp_regnetx3p2gf_bgr = bifpnlite_pixel2pixel_aspp_regnetx3p2gf