[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / segmentation / _utils.py
1 from collections import OrderedDict
3 import torch
4 from torch import nn
5 from torch.nn import functional as F
8 class _SimpleSegmentationModel(nn.Module):
9 def __init__(self, backbone, classifier, aux_classifier=None):
10 super(_SimpleSegmentationModel, self).__init__()
11 self.backbone = backbone
12 self.classifier = classifier
13 self.aux_classifier = aux_classifier
15 # weight initialization
16 for m in self.modules():
17 if isinstance(m, torch.nn.Conv2d):
18 torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
19 if m.bias is not None:
20 torch.nn.init.zeros_(m.bias)
21 elif isinstance(m, torch.nn.BatchNorm2d):
22 torch.nn.init.ones_(m.weight)
23 torch.nn.init.zeros_(m.bias)
24 elif isinstance(m, torch.nn.Linear):
25 torch.nn.init.normal_(m.weight, 0, 0.01)
26 torch.nn.init.zeros_(m.bias)
29 def forward(self, x):
30 input_shape = x.shape[-2:]
31 # contract: features is a dict of tensors
32 features = self.backbone(x)
34 result = OrderedDict()
35 x = features["out"]
36 x = self.classifier(x)
37 x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
38 result["out"] = x
40 if self.aux_classifier is not None:
41 x = features["aux"]
42 x = self.aux_classifier(x)
43 x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
44 result["aux"] = x
46 return result