[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet.py
1 import torch
2 from .... import xnn
3 import torch.nn.functional as F
4 import copy
def split_output_channels(output, output_channels):
    """Split a combined output into one chunk per task along the channel axis.

    Args:
        output: Either a list/tuple (returned unchanged), or an object with a
            ``.shape`` that supports slicing. A 3-dimensional ``output`` is
            sliced along its first axis, a 4-dimensional one along its second
            axis.
        output_channels: Sequence with the number of channels of each task.

    Returns:
        ``output`` itself if it is already a list/tuple, ``[output]`` if there
        is a single entry in ``output_channels``, otherwise a list with one
        slice of ``output`` per entry in ``output_channels``.

    Raises:
        AssertionError: If ``output`` needs splitting but is neither 3- nor
            4-dimensional.
    """
    if isinstance(output, (list, tuple)):
        return output
    if len(output_channels) == 1:
        return [output]

    num_dims = len(output.shape)
    task_outputs = []
    start_ch = 0
    for num_ch in output_channels:
        end_ch = start_ch + num_ch
        if num_dims == 3:
            task_outputs.append(output[start_ch:end_ch, ...])
        elif num_dims == 4:
            task_outputs.append(output[:, start_ch:end_ch, ...])
        else:
            # Raised explicitly instead of `assert False` so that the check is
            # not stripped when Python runs with -O; the exception type is
            # unchanged for existing callers.
            raise AssertionError('incorrect dimensions')
        start_ch = end_ch
    return task_outputs
28 ###########################################
class Pixel2PixelSimpleDecoder(torch.nn.Module):
    """Decoder made of a single prediction layer applied to the encoder features.

    The prediction layer is an ``xnn.layers.ConvDWSepNormAct2d`` built with
    ``kernel_size=3``, ``normalization=(True, False)`` and
    ``activation=(False, False)``.
    """

    def __init__(self, input_channels, output_channels):
        """Create the prediction layer mapping input_channels to output_channels."""
        super().__init__()
        pred_layer = xnn.layers.ConvDWSepNormAct2d(
            input_channels,
            output_channels,
            kernel_size=3,
            normalization=(True, False),
            activation=(False, False),
        )
        self.pred = pred_layer

    def forward(self, x, x_features, x_list):
        """Return the prediction computed from ``x_features``.

        ``x`` and ``x_list`` are accepted so the call signature matches how
        decoders are invoked, but they are not used here.
        """
        prediction = self.pred(x_features)
        return prediction
40 ###########################################
class Pixel2PixelNet(torch.nn.Module):
    """Encoder followed by one or more decoders.

    The following attributes of ``model_config`` are read: ``normalize_input``,
    ``input_channels``, ``output_channels``, ``num_decoders``,
    ``split_outputs``, ``multi_task``, ``multi_task_type``, ``output_type``,
    ``shortcut_channels`` (only when ``num_decoders == 0``) and the method
    ``split(index)`` (once per decoder when ``num_decoders > 0``).
    """

    def __init__(self, base_model, DecoderClass, model_config):
        """Build the network.

        Args:
            base_model: Module used as the encoder. It is called with the
                (optionally normalised) input list and must return a
                ``(features, list)`` pair.
            DecoderClass: Callable that builds a decoder from a per-decoder
                config obtained via ``model_config.split(index)``.
            model_config: Configuration object (see the class docstring for
                the attributes that are read).

        Raises:
            AssertionError: If ``num_decoders`` is non-zero and does not match
                ``len(model_config.output_type)``, or if ``output_type`` and
                ``output_channels`` have different lengths.
        """
        super().__init__()
        # optional per-input normalisation layers, one per entry in input_channels
        if model_config.normalize_input:
            self.normalisers = torch.nn.ModuleList(
                [xnn.layers.DefaultNorm2d(i_chan) for i_chan in model_config.input_channels])
        else:
            self.normalisers = None
        self.encoder = base_model
        self.output_channels = model_config.output_channels
        # default to one decoder per entry in output_channels
        if model_config.num_decoders is None:
            self.num_decoders = len(model_config.output_channels)
        else:
            self.num_decoders = model_config.num_decoders
        self.split_outputs = model_config.split_outputs
        if model_config.multi_task:
            self.multi_task = xnn.layers.MultiTask(self.num_decoders, model_config.multi_task_type,
                                                   model_config.output_type)
        else:
            self.multi_task = None

        #if model_config.freeze_encoder:
            #xnn.utils.freeze_bn(self.encoder)

        assert (self.num_decoders == 0 or (self.num_decoders == len(model_config.output_type))), \
            'num_decoders specified is not correct'
        self.decoders = torch.nn.ModuleList()

        if self.num_decoders == 0:
            # A ModuleList only supports integer indices, so the single simple
            # decoder is appended (it ends up at index 0).
            self.decoders.append(Pixel2PixelSimpleDecoder(model_config.shortcut_channels[-1],
                                                          sum(model_config.output_channels)))
        elif self.num_decoders > 0:
            assert len(model_config.output_type) == len(model_config.output_channels), \
                'output_types and output_channels should have the same length'

            for o_idx in range(self.num_decoders):
                model_config_d = model_config.split(o_idx)
                # disable argmax in case multiple decoders are joined into one.
                if (self.num_decoders == 1) and (model_config.output_type is not None):
                    model_config_d.output_type = ','.join(model_config.output_type)
                #
                decoder = DecoderClass(model_config_d)
                #if model_config_d.freeze_decoder:
                    #xnn.utils.freeze_bn(decoder)

                self.decoders.append(decoder)
            #
        #

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise Conv2d, BatchNorm2d and Linear sub-modules in place.

        Conv2d weights use Kaiming-normal (fan_out, relu), BatchNorm2d weights
        are set to ``1.0 - 1e-5``, Linear weights are drawn from a normal
        distribution with std 0.01, and every existing bias is set to zero.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1.0 - (1e-5))
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                # Linear layers created with bias=False have no bias tensor
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    def forward(self, x_inp):
        """Run the encoder and every decoder.

        Args:
            x_inp: Iterable of inputs; when normalisers are present each input
                is paired with its normaliser in order.

        Returns:
            A list with one output per decoder, or, when at most one decoder
            is configured and ``split_outputs`` is set, the result of
            ``split_output_channels`` on the first decoder output.
        """
        # BN based normalising
        if self.normalisers:
            x_list = [norm(x) for (x, norm) in zip(x_inp, self.normalisers)]
        else:
            x_list = x_inp

        # base encoder module
        x_features, x_list = self.encoder(x_list)

        if self.multi_task:
            x_features_split = self.multi_task(x_features)
        else:
            # One entry per actual decoder: with num_decoders == 0 there is
            # still a single simple decoder that needs the features.
            x_features_split = [x_features] * len(self.decoders)

        # decoder modules; each one receives the un-normalised input, its
        # features and the list returned by the encoder
        x_out = []
        for d_idx, decoder in enumerate(self.decoders):
            x_out.append(decoder(x_inp, x_features_split[d_idx], x_list))

        if self.num_decoders <= 1 and self.split_outputs:
            x_out = split_output_channels(x_out[0], self.output_channels)
        return x_out