]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite_internal.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / deeplabv3lite_internal.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 #   list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 #   this list of conditions and the following disclaimer in the documentation
13 #   and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 #   contributors may be used to endorse or promote products derived from
17 #   this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
32 """
33 Reference:
35 Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,
36 Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam,
37 Google Inc., https://arxiv.org/pdf/1802.02611.pdf
38 """
40 import numpy as np
41 from collections import OrderedDict
42 import torch
43 from .... import xnn
46 from .pixel2pixelnet import *
48 try: from .pixel2pixelnet_internal import *
49 except: pass
51 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, MobileNetV2TVNV12MI4, MobileNetV2TVGWSMI4
52 from .deeplabv3lite import DeepLabV3LiteDecoder, DeepLabV3Lite
53 from .deeplabv3lite import get_config_deeplav3lite_mnv2, deeplabv3lite_mobilenetv2_tv
56 __all__ = ['get_config_deeplav3lite_mnv2_gws', 'deeplabv3lite_mobilenetv2_tv_gws',
57            'deeplabv3lite_mobilenetv2_tv_es32', 'deeplabv3lite_mobilenetv2_tv_mi4_es32',
58                    'deeplabv3lite_mobilenetv2_tv_nv12', 'student_teacher_learner_nv12']
61 ######################################
62 class DeepLabV3LiteMobileNetV2TVNV12(DeepLabV3Lite):
63     def __init__(self, model_config):
64         model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
65         # encoder setup
66         model_config_e = model_config.clone()
67         base_model = MobileNetV2TVNV12MI4(model_config=model_config_e)
68         # decoder setup
69         super().__init__(base_model, model_config)
72 def deeplabv3lite_mobilenetv2_tv_nv12(model_config, pretrained=False):
73     model = DeepLabV3LiteMobileNetV2TVNV12(model_config)
74     num_inputs = len(model_config.input_channels)
75     num_decoders = len(model_config.output_channels) if (
76                 model_config.num_decoders is None) else model_config.num_decoders
77     if num_inputs > 1:
78         change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
79                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
80                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
81     else:
82         change_names_dict = {'^features.': 'encoder.features.',
83                              '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
84     #
86     if pretrained:
87         model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
89     return model, change_names_dict
92 class StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(DeepLabV3LiteMobileNetV2TVNV12):
93     def __init__(self, **kwargs):
94         super().__init__(**kwargs)
95         self.teacher, _ = deeplabv3lite_mobilenetv2_tv(**kwargs)
96         self.teacher.encoder.features = self.teacher.encoder.features[:1]
97         self.encoder.features = self.encoder.features[:1]
99         self.decoders = None
100         self.teacher.decoders = None
101         self.sub = xnn.layers.SubtractBlock(signed=True)
103     def forward(self, x):
104         x = x[0]
105         pred = self.encoder.features(x)
106         target = self.teacher.encoder.features(x[2])
107         #diff = target - pred
108         diff = self.sub((pred, target))
109         return diff
112 def student_teacher_learner_nv12(model_config, pretrained=None):
113     model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
114     # encoder setup
115     model_config_e = model_config.clone()
116     model = StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(model_config=model_config_e)
117     change_names_dict = {'^encoder.features.': 'teacher.encoder.features.', '^features.': 'teacher.encoder.features.',
118                          '^decoders.': 'teacher.decoders.'}
120     if pretrained:
121         model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
123     return model, change_names_dict
126 ###########################################
127 # Groupwise Seperable (GWS) convolutions
128 def get_config_deeplav3lite_mnv2_gws():
129     model_config = get_config_deeplav3lite_mnv2()
130     model_config.groupwise_sep = True
131     model_config.shortcut_channels = (64, 64*5)
132     model_config.shortcut_out = 56
133     model_config.decoder_chan = 252
134     model_config.aspp_chan = 252
135     model_config.aspp_grps = 4
136     model_config.fastdown = False
137     
138     return model_config
141 class DeepLabV3LiteGWSDecoder(DeepLabV3LiteDecoder):
142     def __init__(self, model_config):
143         super().__init__(model_config)
146 class DeepLabV3LiteGWS(Pixel2PixelNet):
147     def __init__(self, base_model, model_config):
148         model_config = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)
149         super().__init__(base_model, DeepLabV3LiteGWSDecoder, model_config)
152 def deeplabv3lite_mobilenetv2_tv_gws(model_config, pretrained=None):
153     model_config = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)
155     #adjust shortcut channels to accomodate gropus
156     flr = lambda a : (a//model_config.group_size_dws)*model_config.group_size_dws
157     enc_dec_dws_ratio_lcm = int(np.lcm(model_config.group_size_dws, 4))
158     flr_lcm = lambda a: (a // enc_dec_dws_ratio_lcm) * enc_dec_dws_ratio_lcm
159     model_config.shortcut_channels = (flr(model_config.shortcut_channels[0]), flr_lcm(model_config.shortcut_channels[1]))
161     # encoder setup
162     model_config_e = model_config.clone()
163     model_config_e.output_stride = np.prod(model_config_e.strides)
164     base_model = MobileNetV2TVGWSMI4(model_config_e)
165     # decoder setup
166     # experimenting with hybrid depth wise separable i.e Ni/G is more than 1 for DWS and group size if set to Ni/G for pointwise conv.
167     # Also shuffle is used in between this two layers
168     model = DeepLabV3LiteGWS(base_model, model_config)
170     change_names_dict = {'^features.': 'encoder.features.'}
171     if pretrained:
172         model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
174     return model, change_names_dict
177 ######################################
178 # DeepLabV3Lite, but with with encoder stride of 32
179 def deeplabv3lite_mobilenetv2_tv_es32(model_config, pretrained=None):
180     # encoder setup
181     model_config_e = model_config.clone()
182     model_config_e.strides = (2,2,2,2,2)
183     encoder_stride = np.prod(model_config_e.strides)
184     model_config_e.shortcut_strides = (8, encoder_stride)
185     base_model = MobileNetV2TVMI4(model_config_e)
186     # decoder setup
187     model = DeepLabV3Lite(base_model, model_config)
189     change_names_dict = {'^features.': 'encoder.features.'}
190     if pretrained:
191         model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
193     return model, change_names_dict
196 ######################################
197 # DeepLabV3Lite, but with with encoder stride of 32
198 def deeplabv3lite_mobilenetv2_tv_mi4_es32(model_config, pretrained=None):
199     # encoder setup
200     model_config_e = model_config.clone()
201     model_config_e.strides = (2,2,2,2,2)
202     encoder_stride = np.prod(model_config_e.strides)
203     model_config_e.shortcut_strides = (8, encoder_stride)
204     base_model = MobileNetV2TVMI4(model_config_e)
205     # decoder setup
206     model = DeepLabV3Lite(base_model, model_config)
208     num_inputs = len(model_config.input_channels)
209     num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
210     change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
211                          '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in
212                                                 range(num_inputs)],
213                          '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
215     if pretrained:
216         model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
218     return model, change_names_dict