torch.nn.ReLU is the recommended activation module; the custom-defined activation module was removed.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / multi_input_net.py
1 import copy
2 import torch
3 from ... import xnn
4 from .mobilenetv2 import MobileNetV2TV
5 from .resnet import resnet50_with_model_config
7 try: from .mobilenetv2_ericsun_internal import *
8 except: pass
10 try: from .mobilenetv2_internal import *
11 except: pass
# Public API of this module. Previously MobileNetV2EricsunMI4, MobileNetV2TVGWSMI4
# and mobilenet_v2_tv_nv12_mi4 were defined below but missing from this list.
__all__ = ['MultiInputNet',
           'mobilenet_v2_tv_mi4', 'mobilenet_v2_tv_gws_mi4', 'mobilenet_v2_ericsun_mi4', 'mobilenet_v2_tv_nv12_mi4',
           'MobileNetV2TVMI4', 'MobileNetV2TVNV12MI4', 'MobileNetV2EricsunMI4', 'MobileNetV2TVGWSMI4', 'ResNet50MI4']
###################################################
def get_config():
    """Build and return the default ConfigNode for MultiInputNet models.

    Callers merge their own settings over these defaults via
    ``get_config().merge_from(model_config)``.
    """
    model_config = xnn.utils.ConfigNode()
    defaults = dict(
        num_inputs=1,               # number of input streams
        input_channels=3,           # channels per stream (callers pass a list for multi-input)
        num_classes=1000,           # set to None to skip the classifier head
        fuse_channels=0,            # channels of the stream-fusion layer
        intermediate_outputs=False, # also return per-block feature maps
        num_input_blocks=0,         # per-stream blocks before fusion
        shared_weights=False,       # share per-stream encoder weights
        fuse_stride=1,              # stride of the fusion conv
    )
    for key, value in defaults.items():
        setattr(model_config, key, value)
    return model_config
31 ###################################################
class MultiInputNet(torch.nn.Module):
    """Wrap a single-input backbone so it can consume multiple input streams.

    Each stream runs through its own copy of the first ``num_input_blocks``
    feature blocks; the per-stream outputs are concatenated on the channel
    dimension and fused with a depthwise-separable conv, after which the
    remaining blocks (and optional classifier) are shared.

    Args:
        Model: backbone class/factory callable accepting ``model_config=``.
        model_config: config node merged over ``get_config()`` defaults.
            ``input_channels`` must be a list/tuple with one entry per stream.
        pretrained: optional pretrained weights (path or state) to load.
    """
    def __init__(self, Model, model_config, pretrained=None):
        model_config = get_config().merge_from(model_config)
        super().__init__()

        self.num_classes = model_config.num_classes
        # one input stream per entry of input_channels (a list/tuple here)
        self.num_inputs = len(model_config.input_channels)
        self.input_channels = model_config.input_channels
        self.fuse_channels = model_config.fuse_channels
        self.intermediate_outputs = model_config.intermediate_outputs
        self.num_input_blocks = model_config.num_input_blocks
        self.shared_weights = model_config.shared_weights
        self.fuse_stride = model_config.fuse_stride

        # in case of multi input net, each input encoder will be a copy of each other;
        # build one single-input backbone and replicate its input blocks below
        model_config_s = model_config.clone()
        model_config_s.input_channels = model_config.input_channels[0]
        model = Model(model_config=model_config_s)

        # adopt the backbone's top-level children (e.g. features, classifier) as our own
        copy_attributes = [n for n, _ in model.named_children()]
        for attr in copy_attributes:
            if hasattr(model, attr):
                val = getattr(model, attr, None)
                setattr(self, attr, val)

        if self.num_inputs>1:
            self.features = self.create_multi_input_features(self.features, self.num_inputs, self.num_input_blocks,
                                                        self.fuse_channels, self.shared_weights, self.fuse_stride)

        self._initialize_weights()

        if model_config.num_inputs>1 and pretrained:
            # single-input checkpoints name blocks 'features.*'; replicate those
            # weights into every per-stream copy
            change_names_dict = {'^features.': ['features.stream{}.'.format(stream) for stream in range(model_config.num_inputs)]}
            # fix: load from the 'pretrained' argument that was checked above
            # (previously read model_config.pretrained, which may be unset)
            xnn.utils.load_weights(self, pretrained, change_names_dict, ignore_size=True, verbose=True)
        elif pretrained:
            xnn.utils.load_weights(self, pretrained, change_names_dict=None, ignore_size=True, verbose=True)

    def _initialize_weights(self):
        """Re-initialize conv (kaiming), batchnorm (1/0) and linear (normal) weights."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, x):
        """Run the (possibly multi-stream) features and optional classifier.

        Returns ``(x, outputs)`` with per-block intermediates when
        ``intermediate_outputs`` is set, otherwise just ``x``.
        """
        if self.num_inputs>1:
            x, outputs = self.forward_multi_input_features(x, self.features, self.num_inputs, self.num_input_blocks,
                                                      self.fuse_channels, self.shared_weights)
        else:
            # TODO: Cleanup. It should not be done in this complicated way.
            # To print the correct size of features.
            outputs = []
            x = x[0] if xnn.utils.is_list(x) else x
            for block_id, block in enumerate(self.features):
                if isinstance(block, torch.nn.AdaptiveAvgPool2d):
                    xnn.utils.print_once('=> feature size is: ', x.size())
                #
                x = block(x)
                outputs += [x]

        if self.num_classes is not None:
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        if self.intermediate_outputs:
            return x, outputs
        else:
            return x

    def create_multi_input_features(self, features, num_inputs, num_input_blocks, fuse_channels, shared_weights, fuse_stride):
        """Replicate the first num_input_blocks of `features` per stream and add a fusion conv.

        Returns a ModuleDict {'stream0'..'streamN', 'streamfuse'}; stream0 keeps
        the original modules (and the shared tail), the other streams get deep copies.

        NOTE(review): with shared_weights=True this returns the plain Sequential
        unchanged, but forward_multi_input_features then indexes it with the
        string key 'streamfuse', which Sequential does not support — the shared
        path looks broken; confirm intended behavior before relying on it.
        """
        if num_inputs == 1 or shared_weights:
            return features

        features_mi = torch.nn.ModuleDict()
        for stream_idx in range(num_inputs):
            feature_stream = []
            for block_id in range(num_input_blocks):
                # stream0 reuses the original modules; other streams get independent copies
                block = features[block_id] if (stream_idx == 0) else copy.deepcopy(features[block_id])
                feature_stream.append(block)
            if stream_idx == 0:
                # only stream0 carries the shared tail (blocks after the fuse point)
                feature_stream.extend(features[num_input_blocks:])
            #
            stream_name = 'stream'+str(stream_idx)
            features_mi[stream_name] = torch.nn.Sequential(*feature_stream)

        # fuse concatenated stream outputs back down to fuse_channels
        features_mi['streamfuse'] = xnn.layers.ConvDWSepNormAct2d(fuse_channels*num_inputs, fuse_channels, stride=fuse_stride, kernel_size=3, activation=(None,None))

        return features_mi

    def forward_multi_input_features(self, x, features_mi, num_inputs, num_input_blocks, fuse_channels, shared_weights):
        """Run each stream's input blocks, fuse the streams, then run the shared tail.

        Returns (x, outputs) where outputs holds per-block intermediates; the
        pre-fusion entries are lists with one tensor per stream.
        """
        outputs = []
        if num_inputs>1:
            if isinstance(x, (list, tuple)):
                assert len(x) == num_inputs, 'incorrect input. number of inputs do not match'
            else:
                # packed single tensor: split channel-wise into num_inputs chunks
                assert x.size(1) == num_inputs*3, 'incorrect input. size of input does not match'
                x = xnn.layers.functional.channel_split_by_chunks(x, num_inputs)

            # shallow copy, just to create a new list
            x = list(x)

            if shared_weights:
                # NOTE(review): features_mi is a plain Sequential here (see
                # create_multi_input_features), so features_mi['streamfuse']
                # below would raise — confirm this path is ever exercised.
                for stream_idx in range(num_inputs):
                    for block_index, block in enumerate(features_mi[:num_input_blocks]):
                        x[stream_idx] = block(x[stream_idx])
                        if stream_idx == 0:
                            outputs += [[x[stream_idx]]]
                        else:
                            outputs[block_index] += [x[stream_idx]]
                fuse_layer = features_mi['streamfuse']
                x = torch.cat(x, dim=1)
                x = fuse_layer(x)

                outputs += [x]

                for block in features_mi[num_input_blocks:]:
                    x = block(x)
                    outputs += [x]

            else:
                for stream_idx in range(num_inputs):
                    stream_name = 'stream' + str(stream_idx)
                    stream = features_mi[stream_name]
                    for block_index, block in enumerate(stream[:num_input_blocks]):
                        x[stream_idx] = block(x[stream_idx])
                        if stream_idx == 0:
                            outputs += [[x[stream_idx]]]
                        else:
                            outputs[block_index] += [x[stream_idx]]

                fuse_layer = features_mi['streamfuse']
                x = torch.cat(x, dim=1)
                x = fuse_layer(x)

                outputs += [x]

                # shared tail lives only in stream0 (see create_multi_input_features)
                stream0 = features_mi['stream0']
                for block in stream0[num_input_blocks:]:
                    x = block(x)
                    outputs += [x]
            #
        else:
            for block in features_mi:
                x = block(x)
                outputs += [x]
        #

        return x, outputs
194 ###################################################
195 # these are the real multi input blocks
class MobileNetV2TVMI4(MultiInputNet):
    """MultiInputNet over the torchvision-style MobileNetV2 backbone.

    Streams are fused after the first 4 feature blocks with 24 fuse channels.
    """
    def __init__(self, model_config):
        model_config.fuse_channels = 24
        model_config.num_input_blocks = 4
        super().__init__(MobileNetV2TV, model_config)


mobilenet_v2_tv_mi4 = MobileNetV2TVMI4
205 # these are the real multi input blocks
class MobileNetV2EricsunMI4(MultiInputNet):
    """MultiInputNet over the Ericsun MobileNetV2 variant.

    Streams are fused after the first 4 feature blocks with 24 fuse channels.
    """
    def __init__(self, model_config):
        model_config.fuse_channels = 24
        model_config.num_input_blocks = 4
        super().__init__(MobileNetV2Ericsun, model_config)


mobilenet_v2_ericsun_mi4 = MobileNetV2EricsunMI4
215 # these are the real multi input blocks
class MobileNetV2TVNV12MI4(MultiInputNet):
    """MultiInputNet over the NV12-input MobileNetV2 variant.

    Streams are fused after the first 4 feature blocks with 24 fuse channels.
    """
    def __init__(self, model_config):
        model_config.fuse_channels = 24
        model_config.num_input_blocks = 4
        super().__init__(MobileNetV2TVNV12, model_config)


mobilenet_v2_tv_nv12_mi4 = MobileNetV2TVNV12MI4
224 # these are the real multi input blocks
class MobileNetV2TVGWSMI4(MultiInputNet):
    """MultiInputNet over the GWS MobileNetV2 variant.

    Streams are fused after the first 4 feature blocks with 24 fuse channels.
    """
    def __init__(self, model_config):
        model_config.fuse_channels = 24
        model_config.num_input_blocks = 4
        super().__init__(MobileNetV2TVGWS, model_config)


mobilenet_v2_tv_gws_mi4 = MobileNetV2TVGWSMI4
234 ###################################################
# these are multi input blocks; NOTE(review): this comment claimed num_input_blocks is set to 0,
# but the code below sets it to 4 — confirm which is intended
class ResNet50MI4(MultiInputNet):
    """MultiInputNet over the ResNet50 backbone.

    Streams are fused after the first 4 feature blocks with 64 fuse channels.
    """
    def __init__(self, model_config):
        model_config.fuse_channels = 64
        model_config.num_input_blocks = 4
        super().__init__(resnet50_with_model_config, model_config)