]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/segmentation/_utils.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / segmentation / _utils.py
1 from collections import OrderedDict
3 import torch
4 from torch import nn
5 from torch.nn import functional as F
8 class _SimpleSegmentationModel(nn.Module):
9     def __init__(self, backbone, classifier, aux_classifier=None):
10         super(_SimpleSegmentationModel, self).__init__()
11         self.backbone = backbone
12         self.classifier = classifier
13         self.aux_classifier = aux_classifier
15         # weight initialization
16         for m in self.modules():
17             if isinstance(m, torch.nn.Conv2d):
18                 torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
19                 if m.bias is not None:
20                     torch.nn.init.zeros_(m.bias)
21             elif isinstance(m, torch.nn.BatchNorm2d):
22                 torch.nn.init.ones_(m.weight)
23                 torch.nn.init.zeros_(m.bias)
24             elif isinstance(m, torch.nn.Linear):
25                 torch.nn.init.normal_(m.weight, 0, 0.01)
26                 torch.nn.init.zeros_(m.bias)
29     def forward(self, x):
30         input_shape = x.shape[-2:]
31         # contract: features is a dict of tensors
32         features = self.backbone(x)
34         result = OrderedDict()
35         x = features["out"]
36         x = self.classifier(x)
37         x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
38         result["out"] = x
40         if self.aux_classifier is not None:
41             x = features["aux"]
42             x = self.aux_classifier(x)
43             x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
44             result["aux"] = x
46         return result