0d6c5b6da449d782fa00e7e724f2e939999ae099
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / unetlite_pixel2pixel.py
1 import torch
2 import numpy as np
3 from .... import xnn
5 from .pixel2pixelnet import *
6 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
# Names exported by ``from <this module> import *``: the network and decoder
# classes followed by the model factory functions.
__all__ = [
    'UNetLitePixel2PixelASPP',
    'UNetLitePixel2PixelDecoder',
    'unetlite_pixel2pixel_aspp_mobilenetv2_tv',
    'unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd',
    'unetlite_pixel2pixel_aspp_resnet50',
    'unetlite_pixel2pixel_aspp_resnet50_fd',
]
14 # config settings for mobilenetv2 backbone
def get_config_unetlitep2p_mnv2():
    """Create the default UNetLite pixel2pixel configuration for a MobileNetV2 backbone.

    Returns:
        A freshly created ``xnn.utils.ConfigNode`` holding the default values
        set below. The model factory functions in this file merge a
        caller-supplied config into it.
    """
    cfg = xnn.utils.ConfigNode()

    # task description - no default, expected to be supplied by the caller
    cfg.num_classes = None
    cfg.num_decoders = None

    # architecture switches
    cfg.intermediate_outputs = True
    cfg.use_aspp = True
    cfg.use_extra_strides = False
    cfg.groupwise_sep = False
    cfg.fastdown = False
    cfg.width_mult = 1.0
    cfg.target_input_ratio = 1

    # encoder strides; the last shortcut stride is the product of all strides
    cfg.strides = (2, 2, 2, 2, 2)
    total_stride = np.prod(cfg.strides)
    cfg.shortcut_strides = (2, 4, 8, 16, total_stride)
    # this is for mobilenetv2 - change for other networks
    cfg.shortcut_channels = (16, 24, 32, 96, 320)

    # decoder / aspp sizes
    cfg.decoder_chan = 256
    cfg.aspp_chan = 256
    cfg.aspp_dil = (6, 12, 18)

    # smoothing convolution and upsampling settings
    cfg.kernel_size_smooth = 3
    cfg.interpolation_type = 'upsample'
    cfg.interpolation_mode = 'bilinear'

    # output stage
    cfg.final_prediction = True
    cfg.final_upsample = True
    cfg.output_range = None

    # remaining options
    cfg.normalize_input = False
    cfg.split_outputs = False
    cfg.decoder_factor = 1.0
    cfg.activation = xnn.layers.DefaultAct2d
    cfg.linear_dw = False
    cfg.normalize_gradients = False
    cfg.freeze_encoder = False
    cfg.freeze_decoder = False
    cfg.multi_task = False
    return cfg
55 ###########################################
class UNetLitePyramid(torch.nn.Module):
    """Decoder pyramid that repeatedly upsamples, concatenates a shortcut feature and smooths.

    One stage (upsample -> concat -> smoothing conv) is created per
    ``(shortcut_stride, shortcut_channel)`` pair, in the order given.

    Args:
        current_channels: number of channels of the feature the pyramid starts from.
        minimum_channels: lower bound for the output channels of every smoothing conv.
        shortcut_strides: stride of the shortcut feature used at each stage.
        shortcut_channels: channel count of the shortcut feature used at each stage.
        activation: activation passed (as a pair) to every smoothing conv.
        kernel_size_smooth: kernel size of the smoothing convs.
        interpolation_type: forwarded to ``xnn.layers.UpsampleWith``.
        interpolation_mode: forwarded to ``xnn.layers.UpsampleWith``.
    """

    def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode):
        super().__init__()
        self.shortcut_strides = shortcut_strides
        self.shortcut_channels = shortcut_channels
        self.upsamples = torch.nn.ModuleList()
        self.concats = torch.nn.ModuleList()
        self.smooth_convs = torch.nn.ModuleList()

        # slot 0 is an empty placeholder: the starting feature is neither
        # concatenated nor smoothed, real stages begin at index 1
        self.smooth_convs.append(None)
        self.concats.append(None)

        scale_factor = 2
        activation_pair = (activation, activation)
        in_channels = current_channels
        for _, skip_channels in zip(shortcut_strides, shortcut_channels):
            self.upsamples.append(xnn.layers.UpsampleWith(in_channels, in_channels, scale_factor, interpolation_type, interpolation_mode))
            self.concats.append(xnn.layers.CatBlock())
            out_channels = max(minimum_channels, skip_channels)
            self.smooth_convs.append(xnn.layers.ConvDWSepNormAct2d(in_channels + skip_channels, out_channels, kernel_size=kernel_size_smooth, activation=activation_pair))
            # the next stage consumes what this stage produced
            in_channels = out_channels

    def forward(self, x_input, x_list):
        """Run the pyramid.

        Args:
            x_input: the network input tensor; only its shape is used, to work
                out the shape of the shortcut feature needed at each stride.
            x_list: list of features; the last entry is the starting feature and
                the shortcut features are looked up in it by shape.

        Returns:
            List of the feature after each stage, ordered from the last stage
            produced to the starting feature.
        """
        input_shape = x_input.shape
        feat = x_list[-1]

        first_smooth = self.smooth_convs[0]
        if first_smooth is not None:
            feat = first_smooth(feat)
        pyramid = [feat]

        stages = zip(self.concats[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)
        for concat, smooth_conv, stride, skip_channels, upsample in stages:
            # shape of the shortcut feature at this stride, with its channel count filled in
            skip_shape = xnn.utils.get_shape_with_stride(input_shape, stride)
            skip_shape[1] = skip_channels
            skip = xnn.utils.get_blob_from_list(x_list, skip_shape)

            # upsample the running feature and merge the shortcut into it
            feat = upsample(feat)
            if concat is not None:
                feat = concat((feat, skip))
            if smooth_conv is not None:
                feat = smooth_conv(feat)
            pyramid.append(feat)

        return pyramid[::-1]
106 ###########################################
class UNetLitePixel2PixelDecoder(torch.nn.Module):
    """UNet-style lightweight decoder for pixel-to-pixel prediction.

    The decoder first transforms the last encoder feature with ``rfblock``
    (an ASPP block, two stride-2 convs, or a 1x1 conv depending on the config),
    then runs a ``UNetLitePyramid`` over the encoder shortcuts and finally,
    if ``final_prediction`` is set, applies the prediction/upsample modules.

    Args:
        model_config: configuration object. The attributes read here are
            ``activation``, ``output_type``, ``decoder_chan``, ``decoder_factor``,
            ``use_aspp``, ``aspp_chan``, ``aspp_dil``, ``use_extra_strides``,
            ``shortcut_strides``, ``shortcut_channels``, ``output_channels``,
            ``kernel_size_smooth``, ``interpolation_type``, ``interpolation_mode``,
            ``final_prediction``, ``final_upsample`` and ``target_input_ratio``.
    """

    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        activation = self.model_config.activation
        self.output_type = model_config.output_type
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)

        # block applied to the last encoder feature before the pyramid
        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
            current_channels = decoder_channels
        else:
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels

        minimum_channels = max(self.model_config.output_channels*2, 32)
        # shortcuts from deepest to shallowest, skipping the deepest one
        # (that feature is the pyramid's starting point, not a shortcut)
        shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
        shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
        self.unet = UNetLitePyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                                    self.model_config.interpolation_type, self.model_config.interpolation_mode)
        # channel count produced by the last pyramid stage
        current_channels = max(minimum_channels, shortcut_channels[-1])

        # add prediction & upsample modules
        # NOTE(review): presumably this attaches self.pred and self.upsample,
        # which forward() uses - confirm against add_lite_prediction_modules
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))

    def forward(self, x_input, x, x_list):
        """Decode the encoder features into the output.

        Args:
            x_input: list or tuple (at most two entries) whose first entry is
                the network input tensor; only its shape is used.
            x: the last encoder feature; must be the same object as ``x_list[-1]``.
            x_list: sequence of encoder features. It is not modified: the
                decoder works on a shallow copy.

        Returns:
            The decoder output. When not training and ``output_type`` is
            ``'segmentation'``, the argmax over dim 1 is returned instead
            (only if ``final_upsample`` is set).
        """
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the features must the last one in x_list'
        x_input = x_input[0]
        in_shape = x_input.shape

        # work on a private copy so that the caller's list is neither extended
        # nor has its last entry overwritten (it may be reused by the caller);
        # this also allows a tuple to be passed in
        x_list = list(x_list)

        if self.model_config.use_extra_strides:
            # each extra block adds one more (lower resolution) feature level
            for blk in self.rfblock:
                x = blk(x)
                x_list.append(x)
        elif self.rfblock is not None:
            # the transformed feature replaces the last encoder feature
            x = self.rfblock(x)
            x_list[-1] = x

        x_list = self.unet(x_input, x_list)
        x = x_list[0]

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample(x)

                if (not self.training) and (self.output_type == 'segmentation'):
                    x = torch.argmax(x, dim=1, keepdim=True)

                assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'

        return x
181 ###########################################
class UNetLitePixel2PixelASPP(Pixel2PixelNet):
    """``Pixel2PixelNet`` specialised to use ``UNetLitePixel2PixelDecoder`` as its decoder class.

    Args:
        base_model: the encoder, forwarded unchanged to ``Pixel2PixelNet``.
        model_config: the model configuration, forwarded unchanged to ``Pixel2PixelNet``.
    """

    def __init__(self, base_model, model_config):
        decoder_class = UNetLitePixel2PixelDecoder
        super().__init__(base_model, decoder_class, model_config)
187 ###########################################
def unetlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Build the UNetLite pixel2pixel model on a MobileNetV2 encoder.

    Args:
        model_config: caller configuration, merged over the MobileNetV2 defaults.
            ``input_channels`` is read from the merged config, and
            ``output_channels`` as well when ``num_decoders`` is None.
        pretrained: weights passed to ``xnn.utils.load_weights``; nothing is
            loaded when this is falsy.

    Returns:
        Tuple ``(model, change_names_dict)`` where ``change_names_dict`` maps
        parameter-name patterns of a pretrained checkpoint to the names used
        by this model.
    """
    model_config = get_config_unetlitep2p_mnv2().merge_from(model_config)

    # the encoder is built from its own clone of the config
    base_model = MobileNetV2TVMI4(model_config.clone())
    model = UNetLitePixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders
    if num_decoders is None:
        num_decoders = len(model_config.output_channels)

    # decoder 0 of the checkpoint is used for every decoder of this model
    decoder_names = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs <= 1:
        change_names_dict = {
            '^features.': 'encoder.features.',
            '^classifier.': 'encoder.classifier.',
            '^decoders.0.': decoder_names,
        }
    else:
        # multi-input: the encoder features are used for every input stream
        change_names_dict = {
            '^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
            '^classifier.': 'encoder.classifier.',
            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
            '^decoders.0.': decoder_names,
        }

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
215 # fast down sampling model (encoder stride 64 model)
def unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Build the fast-downsampling variant of the MobileNetV2 UNetLite model.

    The caller's config is merged over the MobileNetV2 defaults and then the
    fast-downsampling settings below are forced on top of it, so they win
    over whatever the caller supplied.

    Args:
        model_config: caller configuration.
        pretrained: forwarded to ``unetlite_pixel2pixel_aspp_mobilenetv2_tv``.

    Returns:
        Whatever ``unetlite_pixel2pixel_aspp_mobilenetv2_tv`` returns, i.e.
        ``(model, change_names_dict)``.
    """
    fd_config = get_config_unetlitep2p_mnv2().merge_from(model_config)

    # applied in this order, after the merge
    forced_settings = {
        'fastdown': True,
        'strides': (2, 2, 2, 2, 2),
        'shortcut_strides': (4, 8, 16, 32, 64),
        'shortcut_channels': (16, 24, 32, 96, 320),
        'decoder_chan': 256,
        'aspp_chan': 256,
    }
    for name, value in forced_settings.items():
        setattr(fd_config, name, value)

    return unetlite_pixel2pixel_aspp_mobilenetv2_tv(fd_config, pretrained=pretrained)
227 ###########################################
def get_config_unetlitep2p_resnet50():
    """Create the default UNetLite pixel2pixel configuration for a ResNet50 backbone.

    Starts from the MobileNetV2 defaults and overrides only the shortcut
    strides and shortcut channels.

    Returns:
        The configuration object produced by ``get_config_unetlitep2p_mnv2``
        with the two overrides applied.
    """
    cfg = get_config_unetlitep2p_mnv2()
    # only the delta compared to the one defined for mobilenetv2
    cfg.shortcut_strides = (2, 4, 8, 16, 32)
    cfg.shortcut_channels = (64, 256, 512, 1024, 2048)
    return cfg
def unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
    """Build the UNetLite pixel2pixel model on a ResNet50 encoder.

    Args:
        model_config: caller configuration, merged over the ResNet50 defaults.
            ``input_channels`` is read from the merged config, and
            ``output_channels`` as well when ``num_decoders`` is None.
        pretrained: weights passed to ``xnn.utils.load_weights``; nothing is
            loaded when this is falsy.

    Returns:
        Tuple ``(model, change_names_dict)`` where ``change_names_dict`` maps
        parameter-name patterns of a pretrained checkpoint to the names used
        by this model.
    """
    model_config = get_config_unetlitep2p_resnet50().merge_from(model_config)

    # the encoder is built from its own clone of the config
    base_model = ResNet50MI4(model_config.clone())
    model = UNetLitePixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders
    if num_decoders is None:
        num_decoders = len(model_config.output_channels)

    # the pretrained model provided by torchvision and what is defined here differs slightly
    # note: that this change_names_dict will take effect only if the direct load fails
    # finally take care of the change for unet (features->encoder.features)
    #
    # each row: (checkpoint pattern, per-stream target template, single-input target);
    # a single-input target of None means the pattern is used for multi-input only
    rename_table = (
        ('^conv1.', 'encoder.features.stream{}.conv1.', 'encoder.features.conv1.'),
        ('^bn1.', 'encoder.features.stream{}.bn1.', 'encoder.features.bn1.'),
        ('^relu.', 'encoder.features.stream{}.relu.', 'encoder.features.relu.'),
        ('^maxpool.', 'encoder.features.stream{}.maxpool.', 'encoder.features.maxpool.'),
        ('^layer', 'encoder.features.stream{}.layer', 'encoder.features.layer'),
        ('^features.', 'encoder.features.stream{}.', 'encoder.features.'),
        ('^encoder.features.', 'encoder.features.stream{}.', None),
    )

    change_names_dict = {}
    for pattern, stream_template, single_target in rename_table:
        if num_inputs > 1:
            change_names_dict[pattern] = [stream_template.format(stream) for stream in range(num_inputs)]
        elif single_target is not None:
            change_names_dict[pattern] = single_target
    # decoder 0 of the checkpoint is used for every decoder of this model
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def unetlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
    """Build the fast-downsampling variant of the ResNet50 UNetLite model.

    The caller's config is merged over the ResNet50 defaults and then the
    fast-downsampling settings below are forced on top of it, so they win
    over whatever the caller supplied.

    Args:
        model_config: caller configuration.
        pretrained: forwarded to ``unetlite_pixel2pixel_aspp_resnet50``.

    Returns:
        Whatever ``unetlite_pixel2pixel_aspp_resnet50`` returns, i.e.
        ``(model, change_names_dict)``.
    """
    fd_config = get_config_unetlitep2p_resnet50().merge_from(model_config)

    # applied in this order, after the merge
    forced_settings = {
        'fastdown': True,
        'strides': (2, 2, 2, 2, 2),
        # six shortcut levels here; a five-level alternative was (4,8,16,32,64)
        'shortcut_strides': (2, 4, 8, 16, 32, 64),
        # matching five-level alternative was (64,256,512,1024,2048)
        'shortcut_channels': (64, 64, 256, 512, 1024, 2048),
        # 128 was the alternative value for both of the following
        'decoder_chan': 256,
        'aspp_chan': 256,
    }
    for name, value in forced_settings.items():
        setattr(fd_config, name, value)

    return unetlite_pixel2pixel_aspp_resnet50(fd_config, pretrained=pretrained)