[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / deeplabv3lite_internal.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 # list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 # this list of conditions and the following disclaimer in the documentation
13 # and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
32 """
33 Reference:
35 Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,
36 Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam,
37 Google Inc., https://arxiv.org/pdf/1802.02611.pdf
38 """
40 import numpy as np
41 from collections import OrderedDict
42 import torch
43 from .... import xnn
46 from .pixel2pixelnet import *
48 try: from .pixel2pixelnet_internal import *
49 except: pass
51 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, MobileNetV2TVNV12MI4, MobileNetV2TVGWSMI4
52 from .deeplabv3lite import DeepLabV3LiteDecoder, DeepLabV3Lite
53 from .deeplabv3lite import get_config_deeplav3lite_mnv2, deeplabv3lite_mobilenetv2_tv
# Names exported by ``from <this module> import *``.
__all__ = [
    'get_config_deeplav3lite_mnv2_gws',
    'deeplabv3lite_mobilenetv2_tv_gws',
    'deeplabv3lite_mobilenetv2_tv_es32',
    'deeplabv3lite_mobilenetv2_tv_mi4_es32',
    'deeplabv3lite_mobilenetv2_tv_nv12',
    'student_teacher_learner_nv12',
]
61 ######################################
class DeepLabV3LiteMobileNetV2TVNV12(DeepLabV3Lite):
    """DeepLabV3Lite model that uses ``MobileNetV2TVNV12MI4`` as its encoder.

    The caller's ``model_config`` is combined with the configuration returned by
    ``get_config_deeplav3lite_mnv2()`` through ``merge_from``. A clone of the result
    is handed to the encoder; the result itself is handed to ``DeepLabV3Lite``.
    """

    def __init__(self, model_config):
        merged_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
        # encoder: built from its own copy of the configuration
        encoder = MobileNetV2TVNV12MI4(model_config=merged_config.clone())
        # decoder: set up by the DeepLabV3Lite base class
        super().__init__(encoder, merged_config)
def deeplabv3lite_mobilenetv2_tv_nv12(model_config, pretrained=False):
    """Build a ``DeepLabV3LiteMobileNetV2TVNV12`` model and optionally load weights.

    Args:
        model_config: configuration object. ``input_channels``, ``output_channels``
            and ``num_decoders`` are read from it in this function.
        pretrained: when truthy, passed to ``xnn.utils.load_weights`` as the
            weights to load into the model.

    Returns:
        A tuple ``(model, change_names_dict)`` where ``change_names_dict`` is the
        name-mapping dictionary that is given to ``xnn.utils.load_weights``.
    """
    model = DeepLabV3LiteMobileNetV2TVNV12(model_config)

    num_inputs = len(model_config.input_channels)
    # an explicit num_decoders wins; otherwise use one decoder per output entry
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders

    if num_inputs > 1:
        # several inputs: each source prefix maps to one target prefix per stream
        change_names_dict = {
            '^features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
            '^encoder.features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
        }
    else:
        change_names_dict = {'^features.': 'encoder.features.'}

    # in both cases the first decoder's names map onto every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
class StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(DeepLabV3LiteMobileNetV2TVNV12):
    """Student/teacher pair that outputs the difference of their first feature stages.

    The student is this NV12 model; the teacher is created with
    ``deeplabv3lite_mobilenetv2_tv``. Both encoders are cut down to
    ``features[:1]`` and both sets of decoders are replaced with ``None``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # the factory returns (model, change_names_dict); only the model is kept
        teacher_model, _ = deeplabv3lite_mobilenetv2_tv(**kwargs)
        self.teacher = teacher_model

        # keep only the leading slice of each feature extractor
        teacher_model.encoder.features = teacher_model.encoder.features[:1]
        self.encoder.features = self.encoder.features[:1]

        # decoders are not used by forward()
        self.decoders = None
        self.teacher.decoders = None

        self.sub = xnn.layers.SubtractBlock(signed=True)

    def forward(self, x):
        """Return ``self.sub((student_features, teacher_features))``.

        The student consumes ``x[0]`` and the teacher consumes ``x[0][2]``.
        NOTE(review): the layout of ``x`` is not visible here - confirm with callers.
        """
        student_input = x[0]
        student_features = self.encoder.features(student_input)
        teacher_features = self.teacher.encoder.features(student_input[2])
        # same as the earlier "target - pred" idea, but done through the SubtractBlock layer
        return self.sub((student_features, teacher_features))
def student_teacher_learner_nv12(model_config, pretrained=None):
    """Build a ``StudentTeacherDeepLabV3LiteMobileNetV2TVNV12`` and optionally load weights.

    Args:
        model_config: configuration object, combined with the result of
            ``get_config_deeplav3lite_mnv2()`` through ``merge_from``.
        pretrained: when truthy, passed to ``xnn.utils.load_weights``.

    Returns:
        A tuple ``(model, change_names_dict)``. Every entry of the dictionary maps
        a source name prefix to a prefix under ``teacher.``.
    """
    merged_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # the model is constructed from a clone of the merged configuration
    model = StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(model_config=merged_config.clone())

    change_names_dict = {
        '^encoder.features.': 'teacher.encoder.features.',
        '^features.': 'teacher.encoder.features.',
        '^decoders.': 'teacher.decoders.',
    }

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
126 ###########################################
# Groupwise Separable (GWS) convolutions
def get_config_deeplav3lite_mnv2_gws():
    """Return the ``get_config_deeplav3lite_mnv2()`` config with the GWS settings applied."""
    gws_settings = (
        ('groupwise_sep', True),
        ('shortcut_channels', (64, 64 * 5)),
        ('shortcut_out', 56),
        ('decoder_chan', 252),
        ('aspp_chan', 252),
        ('aspp_grps', 4),
        ('fastdown', False),
    )

    model_config = get_config_deeplav3lite_mnv2()
    # apply the overrides one at a time, in the listed order
    for field_name, field_value in gws_settings:
        setattr(model_config, field_name, field_value)

    return model_config
class DeepLabV3LiteGWSDecoder(DeepLabV3LiteDecoder):
    """Decoder class used by ``DeepLabV3LiteGWS``; adds nothing to ``DeepLabV3LiteDecoder``."""

    def __init__(self, model_config):
        # construction is delegated unchanged to the parent decoder
        super(DeepLabV3LiteGWSDecoder, self).__init__(model_config)
class DeepLabV3LiteGWS(Pixel2PixelNet):
    """``Pixel2PixelNet`` that uses ``DeepLabV3LiteGWSDecoder`` as its decoder class.

    The supplied ``model_config`` is combined with the result of
    ``get_config_deeplav3lite_mnv2_gws()`` through ``merge_from`` before it is
    passed to the base class.
    """

    def __init__(self, base_model, model_config):
        gws_config = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)
        super().__init__(base_model, DeepLabV3LiteGWSDecoder, gws_config)
def deeplabv3lite_mobilenetv2_tv_gws(model_config, pretrained=None):
    """Build a ``DeepLabV3LiteGWS`` model on a ``MobileNetV2TVGWSMI4`` encoder.

    Args:
        model_config: configuration object, combined with the result of
            ``get_config_deeplav3lite_mnv2_gws()`` through ``merge_from``.
            ``group_size_dws``, ``shortcut_channels`` and ``strides`` are read here.
        pretrained: when truthy, passed to ``xnn.utils.load_weights``.

    Returns:
        A tuple ``(model, change_names_dict)``.
    """
    model_config = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)

    def _round_down(value, step):
        # largest multiple of `step` that does not exceed `value`
        return (value // step) * step

    # adjust shortcut channels to accommodate groups: the first entry is rounded
    # down to a multiple of the group size, the second to a multiple of lcm(group size, 4)
    group_size = model_config.group_size_dws
    enc_dec_dws_ratio_lcm = int(np.lcm(group_size, 4))
    shortcuts = model_config.shortcut_channels
    model_config.shortcut_channels = (
        _round_down(shortcuts[0], group_size),
        _round_down(shortcuts[1], enc_dec_dws_ratio_lcm),
    )

    # encoder setup: works on a clone whose output stride is the product of its strides
    encoder_config = model_config.clone()
    encoder_config.output_stride = np.prod(encoder_config.strides)
    base_model = MobileNetV2TVGWSMI4(encoder_config)

    # decoder setup
    # Original note: experimenting with hybrid depthwise separable convolution, i.e. Ni/G
    # is more than 1 for DWS and the group size is set to Ni/G for the pointwise conv.
    # A shuffle is also used between these two layers.
    model = DeepLabV3LiteGWS(base_model, model_config)

    change_names_dict = {'^features.': 'encoder.features.'}
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
177 ######################################
# DeepLabV3Lite, but with an encoder stride of 32
def deeplabv3lite_mobilenetv2_tv_es32(model_config, pretrained=None):
    """Build a ``DeepLabV3Lite`` model whose encoder uses five stride-2 stages.

    Args:
        model_config: configuration object. A clone is modified for the encoder;
            the object itself is passed to ``DeepLabV3Lite``.
        pretrained: when truthy, passed to ``xnn.utils.load_weights``.

    Returns:
        A tuple ``(model, change_names_dict)``.
    """
    # encoder setup: only the clone receives the stride overrides
    encoder_config = model_config.clone()
    encoder_config.strides = (2,) * 5
    # overall encoder stride is the product of the per-stage strides
    total_stride = np.prod(encoder_config.strides)
    encoder_config.shortcut_strides = (8, total_stride)
    encoder = MobileNetV2TVMI4(encoder_config)

    # decoder setup
    model = DeepLabV3Lite(encoder, model_config)

    change_names_dict = {'^features.': 'encoder.features.'}
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
196 ######################################
# DeepLabV3Lite, but with an encoder stride of 32
def deeplabv3lite_mobilenetv2_tv_mi4_es32(model_config, pretrained=None):
    """Build a ``DeepLabV3Lite`` model with five stride-2 encoder stages and per-stream name mapping.

    Args:
        model_config: configuration object. A clone is modified for the encoder;
            the object itself is passed to ``DeepLabV3Lite``. ``input_channels``,
            ``output_channels`` and ``num_decoders`` are read from it here.
        pretrained: when truthy, passed to ``xnn.utils.load_weights``.

    Returns:
        A tuple ``(model, change_names_dict)``.
    """
    # encoder setup: only the clone receives the stride overrides
    encoder_config = model_config.clone()
    encoder_config.strides = (2,) * 5
    # overall encoder stride is the product of the per-stage strides
    total_stride = np.prod(encoder_config.strides)
    encoder_config.shortcut_strides = (8, total_stride)
    encoder = MobileNetV2TVMI4(encoder_config)

    # decoder setup
    model = DeepLabV3Lite(encoder, model_config)

    num_inputs = len(model_config.input_channels)
    # an explicit num_decoders wins; otherwise use one decoder per output entry
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders

    # each source prefix maps to one target prefix per input stream / decoder
    change_names_dict = {
        '^features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
        '^encoder.features.': ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)],
        '^decoders.0.': ['decoders.{}.'.format(idx) for idx in range(num_decoders)],
    }

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict