Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/segmentation/deeplabv3.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / segmentation / deeplabv3.py
1 import torch
2 from torch import nn
3 from torch.nn import functional as F
5 from ._utils import _SimpleSegmentationModel
6 from pytorch_jacinto_ai import xnn
9 __all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel):
    """
    DeepLabV3 semantic-segmentation model, following
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    This subclass adds no code of its own: construction and forward logic
    come entirely from ``_SimpleSegmentationModel``. The class exists to give
    the architecture a distinct, exported name.

    Args:
        backbone (nn.Module): network that computes the features for the model.
            It should return an OrderedDict[Tensor] whose "out" key holds the
            last feature map used, plus an "aux" key when an auxiliary
            classifier is used.
        classifier (nn.Module): module that takes the "out" element returned
            by the backbone and produces a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during
            training.
    """
class DeepLabHead(nn.Sequential):
    """DeepLabV3 segmentation head.

    Layer order: ASPP (rates 12/24/36) -> 3x3 conv -> BatchNorm -> ReLU ->
    1x1 conv producing one channel per class.

    Args:
        in_channels (int): channel count of the feature map fed to ASPP.
        num_classes (int): number of output channels of the final 1x1 conv.
    """
    def __init__(self, in_channels, num_classes):
        # ASPP's output width; the refinement layers below must consume
        # exactly this many channels, so it is passed to ASPP explicitly.
        width = 256
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, [12, 24, 36], out_channels=width),
            nn.Conv2d(width, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, num_classes, 1),
        )
class ASPPConv(nn.Sequential):
    """One atrous branch of ASPP: dilated 3x3 conv -> BatchNorm -> ReLU.

    Padding is set equal to the dilation, so with the default stride of 1 the
    3x3 kernel keeps the spatial size of the input unchanged.

    Args:
        in_channels (int): input channel count.
        out_channels (int): output channel count.
        dilation (int): dilation (atrous rate) of the 3x3 convolution.
    """
    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, 3,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    Applies global average pooling to 1x1, then 1x1 conv -> BatchNorm -> ReLU.
    The result is then resized to the input's last two dimensions with
    ``xnn.layers.resize_with_scale_factor`` in bilinear mode.

    Args:
        in_channels (int): input channel count.
        out_channels (int): output channel count of the 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        pool = nn.AdaptiveAvgPool2d(1)
        conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU()
        super(ASPPPooling, self).__init__(pool, conv, norm, act)

    def forward(self, x):
        # Remember the input's trailing (presumably H, W) dims before pooling.
        target_size = x.shape[-2:]
        pooled = super(ASPPPooling, self).forward(x)
        # NOTE(review): upsampling uses the xnn helper rather than
        # F.interpolate(..., mode='bilinear', align_corners=False); presumably
        # for the devkit's export/quantization flow -- confirm before changing.
        return xnn.layers.resize_with_scale_factor(pooled, size=target_size, mode='bilinear')
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block.

    Runs parallel branches over the same input:
    a 1x1 conv branch, one ``ASPPConv`` branch per entry of ``atrous_rates``,
    and one ``ASPPPooling`` branch. The branch outputs are concatenated along
    dim 1 and projected back to ``out_channels`` by a 1x1 conv -> BatchNorm ->
    ReLU -> Dropout(0.5) stack.

    Args:
        in_channels (int): channel count of the input feature map.
        atrous_rates (iterable of int): dilation rates; one ``ASPPConv`` branch
            is created per rate.
        out_channels (int): channel count produced by every branch and by the
            final projection. Default: 256.
    """
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # The concatenation in forward() yields len(self.convs) * out_channels
        # channels. Derive the projection's input width from the branch count
        # instead of a fixed 5, which was only correct for exactly 3 rates.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        branch_outputs = [conv(x) for conv in self.convs]
        fused = torch.cat(branch_outputs, dim=1)
        return self.project(fused)