# release commit
# [jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet.py
1 import torch
2 from .... import xnn
3 import torch.nn.functional as F
4 import copy
def split_output_channels(output, output_channels):
    """Split a combined prediction tensor into per-task chunks along the channel axis.

    Args:
        output: either an already-split list/tuple (returned untouched), or a
            tensor of rank 3 (C,H,W) or rank 4 (N,C,H,W) whose channel dim is
            the concatenation of all task outputs.
        output_channels: list of channel counts, one entry per task.

    Returns:
        A list of tensors, one per entry of output_channels, sliced from the
        channel dimension in order. A single-task spec wraps the tensor as-is.
    """
    # Already split upstream — nothing to do.
    if isinstance(output, (list, tuple)):
        return output
    # Single task: no slicing needed, just normalize to a list.
    if len(output_channels) == 1:
        return [output]

    rank = len(output.shape)
    splits = []
    offset = 0
    for width in output_channels:
        if rank == 3:
            splits.append(output[offset:(offset + width), ...])
        elif rank == 4:
            # channel axis is dim 1 when a batch dimension is present
            splits.append(output[:, offset:(offset + width), ...])
        else:
            assert False, 'incorrect dimensions'
        offset += width
    return splits
28 ###########################################
class Pixel2PixelSimpleDecoder(torch.nn.Module):
    """Minimal decoder head: one depthwise-separable conv mapping encoder
    features straight to the prediction channels (normalized, no activation
    on the final pointwise stage)."""

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.pred = xnn.layers.ConvDWSepNormAct2d(
            input_channels, output_channels, kernel_size=3,
            normalization=(True, False), activation=(False, False))

    def forward(self, x, x_features, x_list):
        """Predict from the final encoder features.

        The raw input `x` and the shortcut feature list `x_list` are accepted
        to match the decoder call signature but are not used here.
        """
        prediction = self.pred(x_features)
        return prediction
40 ###########################################
class Pixel2PixelNet(torch.nn.Module):
    """Encoder/decoder network for dense (pixel-to-pixel) prediction.

    Wraps a backbone `base_model` encoder with one or more task decoders:
    * num_decoders == 0: one shared Pixel2PixelSimpleDecoder emits all task
      channels at once; the combined output can be split per task at the end.
    * num_decoders > 0: one DecoderClass instance per task, each built from a
      per-task split of model_config.
    Optionally normalizes each input stream with a BatchNorm-style layer and
    splits encoder features across tasks via xnn.layers.MultiTask.
    """
    def __init__(self, base_model, DecoderClass, model_config):
        super().__init__()
        # One normalizer per input stream, applied before the encoder.
        self.normalisers = torch.nn.ModuleList([xnn.layers.DefaultNorm2d(i_chan) \
                            for i_chan in model_config.input_channels]) if model_config.normalize_input else None
        self.encoder = base_model
        self.output_channels = model_config.output_channels
        # Default: one decoder per task when num_decoders is not given.
        self.num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
        self.split_outputs = model_config.split_outputs
        self.multi_task = xnn.layers.MultiTask(num_splits=self.num_decoders, multi_task_type=model_config.multi_task_type, output_type=model_config.output_type,
                                               multi_task_factors=model_config.multi_task_factors) if model_config.multi_task else None

        #if model_config.freeze_encoder:
            #xnn.utils.freeze_bn(self.encoder)

        assert (self.num_decoders==0 or (self.num_decoders==len(model_config.output_type))), 'num_decoders specified is not correct'
        self.decoders = torch.nn.ModuleList()

        if self.num_decoders == 0:
            # Single shared decoder producing all task channels at once.
            # BUGFIX: ModuleList does not support string-key assignment
            # (self.decoders['0'] = ... raises TypeError) — append instead.
            self.decoders.append(Pixel2PixelSimpleDecoder(model_config.shortcut_channels[-1], sum(model_config.output_channels)))
        elif self.num_decoders > 0:
            assert len(model_config.output_type) == len(model_config.output_channels), 'output_types and output_channels should have the same length'

            for o_idx in range(self.num_decoders):
                model_config_d = model_config.split(o_idx)
                # Disable argmax in case multiple decoders are joined into one.
                if (self.num_decoders == 1) and (model_config.output_type is not None):
                    model_config_d.output_type = ','.join(model_config.output_type)
                #
                decoder = DecoderClass(model_config_d)
                #if model_config_d.freeze_decoder:
                    #xnn.utils.freeze_bn(decoder)

                self.decoders.append(decoder)
            #
        #
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-init convs, near-identity BatchNorm, small-normal Linear."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                # Slightly below 1.0 — presumably to avoid exact-identity scaling
                # at start; TODO confirm the rationale for the 1e-5 offset.
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1.0-(1e-5))
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, x_inp):
        """Run inputs through normalizers, encoder, and all decoders.

        Args:
            x_inp: list of input tensors, one per input stream.

        Returns:
            List of decoder outputs; when a single (or shared) decoder is used
            and split_outputs is set, its combined output is sliced into
            per-task tensors.
        """
        # BN-based input normalization (optional).
        x_list = [norm(x) for (x, norm) in zip(x_inp, self.normalisers)] if self.normalisers else x_inp

        # Base encoder module.
        x_features, x_list = self.encoder(x_list)

        # BUGFIX: replicate features per actual decoder, not per num_decoders —
        # with num_decoders == 0 there is still one shared decoder, and the old
        # range(self.num_decoders) produced an empty list and an IndexError.
        x_features_split = self.multi_task(x_features) if self.multi_task else [x_features for _ in self.decoders]

        # Decoder modules.
        x_out = []
        for decoder, x_feat in zip(self.decoders, x_features_split):
            d_out = decoder(x_inp, x_feat, x_list)
            x_out.append(d_out)

        x_out = split_output_channels(x_out[0], self.output_channels) if (self.num_decoders <= 1 and self.split_outputs) else x_out
        return x_out