cde3b4d04425ccf0321e9d886f04f81c229195e4
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / fpn_pixel2pixel.py
1 import torch
2 import numpy as np
3 from .... import xnn
5 from .pixel2pixelnet import *
6 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
# Public API of this module: decoder/network classes plus the model
# constructor functions registered for the various backbones.
__all__ = ['FPNPixel2PixelASPP', 'FPNPixel2PixelDecoder',
           'fpn_pixel2pixel_aspp_mobilenetv2_tv', 'fpn_pixel2pixel_aspp_mobilenetv2_tv_fd', 'fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd',
           # no aspp models
           'fpn_pixel2pixel_mobilenetv2_tv', 'fpn_pixel2pixel_mobilenetv2_tv_fd',
           # resnet models
           'fpn_pixel2pixel_aspp_resnet50', 'fpn_pixel2pixel_aspp_resnet50_fd',
           ]
18 ###########################################
class FPNPixel2PixelDecoder(torch.nn.Module):
    """FPN-style top-down decoder for dense (pixel2pixel) prediction.

    The deepest encoder feature is passed through a receptive-field block
    (ASPP, an extra-stride pyramid, or a plain 1x1 conv, depending on
    model_config), then repeatedly upsampled and fused with lateral
    shortcut features, and finally projected to output_channels.
    """

    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        activation = self.model_config.activation
        # read again in forward() to decide whether to argmax at inference
        self.output_type = model_config.output_type

        # decoder width, scaled by decoder_factor (e.g. for slim variants)
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)

        # receptive-field block applied to the deepest encoder feature;
        # exactly one of the three variants below is constructed
        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil,
                                                      avg_pool=False, activation=activation)
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid: two extra stride-2 separable convs.
            # NOTE: input taken from shortcut_channels[-3] because the last
            # two shortcut levels are produced by these convs themselves.
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
        else:
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
        current_channels = decoder_channels

        # one 1x1 lateral projection + one smoothing conv per shortcut level,
        # traversed deepest-first; the deepest entry ([::-1][1:] skips it)
        # is already consumed by rfblock above
        self.shortcuts = torch.nn.ModuleList()
        self.smooth_convs = torch.nn.ModuleList()
        for s_stride, feat_chan in zip(self.model_config.shortcut_strides[::-1][1:], self.model_config.shortcut_channels[::-1][1:]):
            shortcut = xnn.layers.ConvNormAct2d(feat_chan, decoder_channels, kernel_size=1, activation=activation)
            self.shortcuts.append(shortcut)
            if self.model_config.smooth_conv:
                smooth_conv = xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=self.model_config.kernel_size_smooth, activation=(activation,activation))
            else:
                # no smoothing: pass the fused feature through unchanged
                smooth_conv = xnn.layers.BypassBlock()
            self.smooth_convs.append(smooth_conv)
            current_channels = decoder_channels

        # prediction
        if self.model_config.final_prediction:
            self.pred = xnn.layers.ConvDWSepNormAct2d(current_channels, self.model_config.output_channels, kernel_size=3, normalization=(True,False), activation=(False,False))

        # upsample1: x2 step used at every fusion stage;
        # upsample2: final jump from shortcut_strides[0] back to input size
        upstride1 = 2
        upstride2 = self.model_config.shortcut_strides[0]
        self.upsample1 = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
        self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """Decode encoder features to the output map.

        x: (input_image,) or (input_image, target) list/tuple; only the
           input image is used here (for shapes and final upsampling).
        x_features: deepest encoder feature blob.
        x_list: intermediate encoder blobs; lateral features are located
           in it by their expected shape (see get_blob_from_list below).
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'
        x_input = x[0]
        in_shape = x_input.shape

        # rfblock at output stride
        if self.model_config.use_aspp:
            x = self.rfblock(x_features)
        elif self.model_config.use_extra_strides:
            x = x_features
            for blk in self.rfblock:
                x = blk(x)
                # append each extra-stride output so the lateral lookup
                # below can find these deeper levels by shape
                x_list += [x]
        else:
            x = self.rfblock(x_features)

        # top-down fusion, deepest lateral first: find the lateral blob by
        # expected (stride, channels) shape, project it, upsample the running
        # feature to it, add, then smooth
        for s_stride, shortcut, smooth_conv, short_chan in zip(self.model_config.shortcut_strides[::-1][1:], self.shortcuts, self.smooth_convs, self.model_config.shortcut_channels[::-1][1:]):
            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            shape_s[1] = short_chan
            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
            x_s = shortcut(x_s)
            x = self.upsample1((x,x_s))
            x = x + x_s
            x = smooth_conv(x)

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample2((x,x_input))

            if (not self.training) and (self.output_type == 'segmentation'):
                # inference-time segmentation output is the class-index map
                x = torch.argmax(x, dim=1, keepdim=True)

            # sanity check: spatial size must match the input image
            # (NOTE(review): assumes final_upsample is enabled — confirm)
            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

        return x
112 ###########################################
class FPNPixel2PixelASPP(Pixel2PixelNet):
    """Pixel2Pixel network pairing a given encoder backbone with
    FPNPixel2PixelDecoder (whether ASPP is actually used inside the
    decoder is controlled by model_config.use_aspp)."""
    def __init__(self, base_model, model_config):
        super().__init__(base_model, FPNPixel2PixelDecoder, model_config)
118 ###########################################
119 # config settings for mobilenetv2 backbone
def get_config_fpnp2p_mnv2():
    """Return the default ConfigNode for FPN pixel2pixel models with a
    MobileNetV2 backbone; callers override fields via merge_from()."""
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None        # None => one decoder per output_channels entry
    model_config.intermediate_outputs = True
    model_config.use_aspp = True            # decoder rfblock = ASPP
    model_config.use_extra_strides = False  # decoder rfblock = extra strided convs
    model_config.groupwise_sep = False
    model_config.fastdown = False

    # encoder strides and the strides/channels of the shortcut (lateral) taps
    model_config.strides = (2,2,2,2,2)
    encoder_stride = np.prod(model_config.strides)
    model_config.shortcut_strides = (4,8,16,encoder_stride)
    model_config.shortcut_channels = (24,32,96,320) # this is for mobilenetv2 - change for other networks
    model_config.smooth_conv = True
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)

    model_config.kernel_size_smooth = 3
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'

    model_config.final_prediction = True
    model_config.final_upsample = True

    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.decoder_factor = 1.0       # multiplier on decoder/aspp channel counts
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Build an FPN pixel2pixel model with a torchvision-style MobileNetV2
    encoder.

    Returns (model, change_names_dict); the dict maps pretrained state-dict
    key patterns onto this model's naming and is also used when loading
    `pretrained` weights here.
    """
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    # encoder setup (gets its own clone of the merged config)
    base_model = MobileNetV2TVMI4(model_config.clone())
    # decoder setup
    model = FPNPixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)
    # decoder weights from a single-decoder checkpoint are replicated to all decoders
    decoder_targets = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        # multi-input: encoder weights are replicated across input streams
        stream_targets = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict = {'^features.': stream_targets,
                             '^classifier.': 'encoder.classifier.',
                             '^encoder.features.': stream_targets,
                             '^decoders.0.': decoder_targets}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^classifier.': 'encoder.classifier.',
                             '^decoders.0.': decoder_targets}

    if pretrained:
        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    return model, change_names_dict
185 # fast down sampling model (encoder stride 64 model)
def fpn_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling MobileNetV2 FPN-ASPP model (encoder stride 64)."""
    cfg = get_config_fpnp2p_mnv2().merge_from(model_config)
    # fd-specific overrides, applied in one place
    for key, value in (('fastdown', True),
                       ('strides', (2,2,2,2,2)),
                       ('shortcut_strides', (8,16,32,64)),
                       ('shortcut_channels', (24,32,96,320)),
                       ('decoder_chan', 256),
                       ('aspp_chan', 256)):
        setattr(cfg, key, value)
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
197 # fast down sampling model (encoder stride 64 model) with fpn decoder channels 128
def fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling MobileNetV2 FPN-ASPP model with 128 decoder/ASPP
    channels and an additional stride-4 shortcut level."""
    cfg = get_config_fpnp2p_mnv2().merge_from(model_config)
    # fd + slim-decoder overrides, applied in one place
    for key, value in (('fastdown', True),
                       ('strides', (2,2,2,2,2)),
                       ('shortcut_strides', (4,8,16,32,64)),
                       ('shortcut_channels', (16,24,32,96,320)),
                       ('decoder_chan', 128),
                       ('aspp_chan', 128)):
        setattr(cfg, key, value)
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
209 ##################
210 # similar to the original fpn model with extra convolutions with strides (no aspp)
def fpn_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
    """MobileNetV2 FPN model without ASPP: extra strided convolutions form
    the pyramid top, similar to the original FPN paper."""
    cfg = get_config_fpnp2p_mnv2().merge_from(model_config)
    # swap the ASPP rfblock for the extra-stride pyramid and extend the shortcuts
    for key, value in (('use_aspp', False),
                       ('use_extra_strides', True),
                       ('shortcut_strides', (4, 8, 16, 32, 64, 128)),
                       ('shortcut_channels', (24, 32, 96, 320, 320, 256))):
        setattr(cfg, key, value)
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
220 # similar to the original fpn model with extra convolutions with strides (no aspp) - fast down sampling model (encoder stride 64 model)
def fpn_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-downsampling (encoder stride 64) MobileNetV2 FPN model without
    ASPP; extra strided convolutions form the pyramid top."""
    cfg = get_config_fpnp2p_mnv2().merge_from(model_config)
    # no-aspp + fd overrides, applied in one place
    for key, value in (('use_aspp', False),
                       ('use_extra_strides', True),
                       ('fastdown', True),
                       ('strides', (2,2,2,2,2)),
                       ('shortcut_strides', (8, 16, 32, 64, 128, 256)),
                       ('shortcut_channels', (24, 32, 96, 320, 320, 256)),
                       ('decoder_chan', 256),
                       ('aspp_chan', 256)):
        setattr(cfg, key, value)
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
234 ###########################################
def get_config_fpnp2p_resnet50():
    """Config for the ResNet50 backbone: the MobileNetV2 defaults with the
    shortcut channel counts replaced by ResNet50's stage widths."""
    cfg = get_config_fpnp2p_mnv2()
    cfg.shortcut_channels = (256,512,1024,2048)
    return cfg
def fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
    """Build an FPN pixel2pixel model with a ResNet50 encoder.

    Returns (model, change_names_dict); the dict remaps torchvision-style
    ResNet50 state-dict keys (conv1/bn1/relu/maxpool/layer*) onto this
    model's naming and is also used when loading `pretrained` weights here.
    """
    model_config = get_config_fpnp2p_resnet50().merge_from(model_config)
    # encoder setup (gets its own clone of the merged config)
    base_model = ResNet50MI4(model_config.clone())
    # decoder setup
    model = FPNPixel2PixelASPP(base_model, model_config)

    # the pretrained model provided by torchvision and what is defined here differs slightly
    # note: that this change_names_dict will take effect only if the direct load fails
    # finally take care of the change for fpn (features->encoder.features)
    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)
    decoder_targets = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    # torchvision ResNet stem/stage key prefixes ('layer' intentionally has no dot)
    stem_keys = ('conv1.', 'bn1.', 'relu.', 'maxpool.', 'layer')
    if num_inputs > 1:
        # multi-input: encoder weights are replicated across input streams
        streams = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict = {'^' + key: [prefix + key for prefix in streams] for key in stem_keys}
        change_names_dict['^features.'] = streams
        change_names_dict['^encoder.features.'] = streams
        change_names_dict['^decoders.0.'] = decoder_targets
    else:
        change_names_dict = {'^' + key: 'encoder.features.' + key for key in stem_keys}
        change_names_dict['^features.'] = 'encoder.features.'
        change_names_dict['^decoders.0.'] = decoder_targets

    if pretrained:
        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    return model, change_names_dict
def fpn_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
    """Fast-downsampling ResNet50 FPN-ASPP model (encoder stride 64).
    (The original kept commented-out alternates with a stride-4 level and
    128 decoder channels; they are not enabled.)"""
    cfg = get_config_fpnp2p_resnet50().merge_from(model_config)
    # fd-specific overrides, applied in one place
    for key, value in (('fastdown', True),
                       ('strides', (2,2,2,2,2)),
                       ('shortcut_strides', (8,16,32,64)),
                       ('shortcut_channels', (256,512,1024,2048)),
                       ('decoder_chan', 256),
                       ('aspp_chan', 256)):
        setattr(cfg, key, value)
    return fpn_pixel2pixel_aspp_resnet50(cfg, pretrained=pretrained)