[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet.py
1 import torch
2 from .... import xnn
3 import torch.nn.functional as F
4 import copy
def split_output_channels(output, output_channels):
    """Split a stacked multi-task prediction into one tensor per task.

    Args:
        output: either an already-split ``list``/``tuple`` (returned unchanged)
            or a single tensor. A 3-D tensor is sliced along dim 0, a 4-D
            tensor along dim 1.
        output_channels: sequence of per-task channel counts, consumed in order.

    Returns:
        ``output`` itself if it is a list/tuple; ``[output]`` when there is a
        single task; otherwise a list with one channel slice of ``output`` per
        entry of ``output_channels``.

    Raises:
        ValueError: if ``output`` has to be split but is neither 3-D nor 4-D.
    """
    if isinstance(output, (list, tuple)):
        return output
    if len(output_channels) == 1:
        return [output]
    if not output_channels:
        return []

    # Channel axis is dim 0 for 3-D tensors and dim 1 for 4-D tensors.
    # An explicit raise (rather than `assert`) keeps this check active under
    # `python -O`, where a stripped assert would silently return no outputs.
    ndim = len(output.shape)
    if ndim not in (3, 4):
        raise ValueError('incorrect dimensions')

    task_outputs = []
    start_ch = 0
    for num_ch in output_channels:
        ch_slice = slice(start_ch, start_ch + num_ch)
        task_outputs.append(output[ch_slice] if ndim == 3 else output[:, ch_slice])
        start_ch += num_ch
    return task_outputs
28 ###########################################
class Pixel2PixelSimpleDecoder(torch.nn.Module):
    """Minimal prediction head: one depthwise-separable 3x3 conv on the encoder features.

    Args:
        input_channels: channel count of the features passed to ``forward``.
        output_channels: number of prediction channels to produce.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # NOTE(review): the (True, False) / (False, False) tuples presumably
        # configure the two stages of ConvDWSepNormAct2d separately — confirm in xnn.
        norm_flags = (True, False)
        act_flags = (False, False)
        self.pred = xnn.layers.ConvDWSepNormAct2d(
            input_channels,
            output_channels,
            kernel_size=3,
            normalization=norm_flags,
            activation=act_flags,
        )

    def forward(self, x, x_features, x_list):
        """Return ``self.pred(x_features)``; ``x`` and ``x_list`` are accepted but unused."""
        prediction = self.pred(x_features)
        return prediction
40 ###########################################
class Pixel2PixelNet(torch.nn.Module):
    """Encoder plus one or more decoders for pixel-to-pixel prediction.

    Args:
        base_model: encoder module, called as ``base_model(x_list)`` and
            expected to return ``(x_features, x_list)``.
        DecoderClass: decoder factory, called once per decoder as
            ``DecoderClass(model_config.split(o_idx))``.
        model_config: configuration object. Attributes read here:
            ``normalize_input``, ``input_channels``, ``output_channels``,
            ``num_decoders``, ``split_outputs``, ``multi_task``,
            ``multi_task_type``, ``output_type``, ``multi_task_factors``,
            ``split()`` and (only when ``num_decoders == 0``)
            ``shortcut_channels``.

    With ``num_decoders == 0`` a single ``Pixel2PixelSimpleDecoder`` is built
    with ``shortcut_channels[-1]`` input channels and ``sum(output_channels)``
    output channels. When ``num_decoders`` is None it defaults to
    ``len(output_channels)``.
    """

    def __init__(self, base_model, DecoderClass, model_config):
        super().__init__()
        # Optional normalisation layer per input tensor.
        if model_config.normalize_input:
            self.normalisers = torch.nn.ModuleList(
                [xnn.layers.DefaultNorm2d(i_chan) for i_chan in model_config.input_channels])
        else:
            self.normalisers = None
        self.encoder = base_model
        self.output_channels = model_config.output_channels
        if model_config.num_decoders is None:
            self.num_decoders = len(model_config.output_channels)
        else:
            self.num_decoders = model_config.num_decoders
        self.split_outputs = model_config.split_outputs
        if model_config.multi_task:
            self.multi_task = xnn.layers.MultiTask(
                num_splits=self.num_decoders, multi_task_type=model_config.multi_task_type,
                output_type=model_config.output_type,
                multi_task_factors=model_config.multi_task_factors)
        else:
            self.multi_task = None

        assert (self.num_decoders == 0 or (self.num_decoders == len(model_config.output_type))), \
            'num_decoders specified is not correct'
        self.decoders = torch.nn.ModuleList()

        if self.num_decoders == 0:
            # ModuleList only supports integer item assignment to existing
            # slots, so the single head must be appended (registered as '0').
            self.decoders.append(
                Pixel2PixelSimpleDecoder(model_config.shortcut_channels[-1], sum(model_config.output_channels)))
        elif self.num_decoders > 0:
            assert len(model_config.output_type) == len(model_config.output_channels), \
                'output_types and output_channels should have the same length'
            for o_idx in range(self.num_decoders):
                model_config_d = model_config.split(o_idx)
                # Disable argmax in case multiple decoders are joined into one
                # (original intent; the decoder sees the joined output_type string).
                if (self.num_decoders == 1) and (model_config.output_type is not None):
                    model_config_d.output_type = ','.join(model_config.output_type)
                self.decoders.append(DecoderClass(model_config_d))

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise Conv2d, BatchNorm2d and Linear layers in all submodules.

        Conv2d: Kaiming-normal weights (fan_out, relu), zero bias if present.
        BatchNorm2d: weight ``1 - 1e-5`` and zero bias, when those parameters exist.
        Linear: weights ~ N(0, 0.01), zero bias if present.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1.0 - (1e-5))
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                # Linear(bias=False) has bias None; zeroing it unguarded would crash.
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x_inp):
        """Run the encoder, then every decoder.

        Args:
            x_inp: sequence of input tensors (zipped with the normalisers when
                input normalisation is enabled).

        Returns:
            List of decoder outputs. When there is at most one decoder and
            ``split_outputs`` is set, the single output is instead split per
            ``output_channels`` by ``split_output_channels``.
        """
        # BN based normalising
        x_list = [norm(x) for (x, norm) in zip(x_inp, self.normalisers)] if self.normalisers else x_inp

        # Base encoder: returns the features plus an updated x_list.
        x_features, x_list = self.encoder(x_list)

        # One feature entry per decoder that actually exists; with
        # num_decoders == 0 there is still one (simple) decoder.
        if self.multi_task:
            x_features_split = self.multi_task(x_features)
        else:
            x_features_split = [x_features] * len(self.decoders)

        # Decoders receive the raw x_inp (not the normalised list), their
        # feature split and the encoder's x_list.
        x_out = []
        for d_idx, decoder in enumerate(self.decoders):
            x_out.append(decoder(x_inp, x_features_split[d_idx], x_list))

        if self.num_decoders <= 1 and self.split_outputs:
            x_out = split_output_channels(x_out[0], self.output_channels)
        return x_out