Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet.py
release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet.py
1 import torch
2 from .... import xnn
3 import torch.nn.functional as F
4 import copy
def split_output_channels(output, output_channels):
    """Split a packed prediction tensor into per-task chunks along channels.

    If *output* is already a list/tuple it is returned unchanged; if only one
    task is configured the tensor is wrapped in a single-element list.
    Otherwise the channel axis (dim 0 for 3-D CHW tensors, dim 1 for 4-D
    NCHW tensors) is sliced into consecutive chunks of the widths given in
    *output_channels*.
    """
    if isinstance(output, (list, tuple)):
        return output
    if len(output_channels) == 1:
        return [output]

    chunks = []
    offset = 0
    for width in output_channels:
        end = offset + width
        if len(output.shape) == 3:
            chunks.append(output[offset:end, ...])
        elif len(output.shape) == 4:
            chunks.append(output[:, offset:end, ...])
        else:
            assert False, 'incorrect dimensions'
        offset = end
    return chunks
28 ###########################################
class Pixel2PixelSimpleDecoder(torch.nn.Module):
    """Minimal decoder: one depthwise-separable conv mapping encoder
    features straight to the prediction channels (norm on the depthwise
    stage only, no activation)."""

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.pred = xnn.layers.ConvDWSepNormAct2d(
            input_channels, output_channels, kernel_size=3,
            normalization=(True, False), activation=(False, False))

    def forward(self, x, x_features, x_list):
        # Only the encoder feature map is used; the raw input `x` and the
        # shortcut list `x_list` are accepted for interface compatibility
        # with richer decoders and ignored here.
        return self.pred(x_features)
40 ###########################################
class Pixel2PixelNet(torch.nn.Module):
    """Encoder/decoder network for pixel-to-pixel prediction tasks.

    Wraps a backbone encoder (`base_model`) and one or more task decoders
    built from `DecoderClass`. With `model_config.num_decoders == 0` a single
    shared `Pixel2PixelSimpleDecoder` predicts all task channels at once and
    the packed output can be split per task via `split_outputs`.
    """

    def __init__(self, base_model, DecoderClass, model_config):
        """
        :param base_model: encoder module; called as encoder(x_list) and
            expected to return (features, shortcut_list).
        :param DecoderClass: decoder constructor; called as
            DecoderClass(per_task_model_config).
        :param model_config: configuration namespace; fields read here:
            normalize_input, input_channels, output_channels, num_decoders,
            split_outputs, multi_task, multi_task_type, output_type,
            shortcut_channels, and a split(idx) method returning a per-task
            copy of the config.
        """
        super().__init__()
        # Optional learned (BN-based) normalization, one per input stream.
        self.normalisers = torch.nn.ModuleList(
            [xnn.layers.DefaultNorm2d(i_chan) for i_chan in model_config.input_channels]) \
            if model_config.normalize_input else None
        self.encoder = base_model
        self.output_channels = model_config.output_channels
        # Default: one decoder per output task.
        self.num_decoders = len(model_config.output_channels) \
            if (model_config.num_decoders is None) else model_config.num_decoders
        self.split_outputs = model_config.split_outputs
        self.multi_task = xnn.layers.MultiTask(self.num_decoders, model_config.multi_task_type,
                                               model_config.output_type) \
            if model_config.multi_task else None

        assert (self.num_decoders == 0 or (self.num_decoders == len(model_config.output_type))), \
            'num_decoders specified is not correct'
        self.decoders = torch.nn.ModuleList()

        if self.num_decoders == 0:
            # One shared decoder producing all task channels packed together.
            # BUGFIX: ModuleList does not support keyed assignment
            # (self.decoders['0'] = ... raised IndexError on an empty list);
            # append() is the correct way to add the module.
            self.decoders.append(Pixel2PixelSimpleDecoder(
                model_config.shortcut_channels[-1], sum(model_config.output_channels)))
        else:
            assert len(model_config.output_type) == len(model_config.output_channels), \
                'output_types and output_channels should have the same length'
            for o_idx in range(self.num_decoders):
                model_config_d = model_config.split(o_idx)
                # Disable per-task argmax when multiple outputs are fused
                # into a single decoder: pass the joint output_type string.
                if (self.num_decoders == 1) and (model_config.output_type is not None):
                    model_config_d.output_type = ','.join(model_config.output_type)
                self.decoders.append(DecoderClass(model_config_d))

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convs, near-identity init for BN, small normal
        init for linear layers."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm2d):
                # Slightly below 1.0 to break exact-identity symmetry.
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1.0 - (1e-5))
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, x_inp):
        """Run encoder then all decoders.

        :param x_inp: list of input tensors, one per input stream.
        :return: list of per-decoder outputs; if a single (or shared) decoder
            produced a packed tensor and split_outputs is set, the packed
            tensor split per task via split_output_channels().
        """
        # BN-based input normalization, when enabled.
        x_list = [norm(x) for (x, norm) in zip(x_inp, self.normalisers)] if self.normalisers else x_inp

        # Base encoder: returns the deep feature map and the shortcut list.
        x_features, x_list = self.encoder(x_list)

        # BUGFIX: size the fallback by len(self.decoders), not num_decoders —
        # with num_decoders == 0 a shared decoder still exists and the old
        # empty list caused an IndexError below.
        x_features_split = self.multi_task(x_features) if self.multi_task \
            else [x_features] * len(self.decoders)

        # Each decoder sees the raw inputs, its feature map and the shortcuts.
        x_out = [decoder(x_inp, x_feat, x_list)
                 for decoder, x_feat in zip(self.decoders, x_features_split)]

        x_out = split_output_channels(x_out[0], self.output_channels) \
            if (self.num_decoders <= 1 and self.split_outputs) else x_out
        return x_out