[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / segmentation / deeplabv3.py
1 import torch
2 from torch import nn
3 from torch.nn import functional as F
5 from ._utils import _SimpleSegmentationModel
6 from pytorch_jacinto_ai import xnn
# Public API for `from <module> import *`: only the model class is exported.
# DeepLabHead / ASPP / ASPPConv / ASPPPooling stay importable by explicit name.
__all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 semantic segmentation model.

    Reference: `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    This class adds no behavior of its own; construction and the forward pass
    are inherited unchanged from ``_SimpleSegmentationModel``. It exists to give
    the architecture a distinct, exported name.

    Arguments:
        backbone (nn.Module): feature extractor. It should return an
            OrderedDict[Tensor] whose "out" entry is the last feature map used,
            plus an "aux" entry when an auxiliary classifier is used.
        classifier (nn.Module): consumes the "out" feature map from the
            backbone and produces a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during
            training.
    """
class DeepLabHead(nn.Sequential):
    """DeepLabV3 head: ASPP, then a 3x3 conv/BN/ReLU block, then a 1x1 conv.

    Args:
        in_channels (int): channels of the feature map fed into the head.
        num_classes (int): output channels of the final 1x1 convolution.
    """

    def __init__(self, in_channels, num_classes):
        # ASPP is built without an explicit out_channels, so it emits its
        # default width (256); the following layers must agree with that.
        width = 256
        layers = [
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(width, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, num_classes, 1),
        ]
        super(DeepLabHead, self).__init__(*layers)
class ASPPConv(nn.Sequential):
    """One dilated ASPP branch: 3x3 dilated conv -> BatchNorm -> ReLU.

    Args:
        in_channels (int): channels of the input feature map.
        out_channels (int): channels produced by this branch.
        dilation (int): dilation rate. It is also used as the padding, so the
            3x3 kernel (stride 1) preserves the spatial size of the input.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPConv, self).__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    Global average pooling to 1x1, then 1x1 conv -> BatchNorm -> ReLU, and
    finally a bilinear resize back to the input's spatial size.

    Args:
        in_channels (int): channels of the input feature map.
        out_channels (int): channels produced by this branch.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super(ASPPPooling, self).__init__(*layers)

    def forward(self, x):
        # Capture the last two (spatial) dims before pooling collapses them.
        out_size = x.shape[-2:]
        # nn.Sequential.forward applies the layers above in order.
        pooled = super(ASPPPooling, self).forward(x)
        # Resize back using xnn's helper in place of the stock
        # F.interpolate(..., mode='bilinear', align_corners=False).
        # NOTE(review): confirm the xnn helper matches that call here.
        return xnn.layers.resize_with_scale_factor(pooled, size=out_size, mode='bilinear')
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches on the same input, concatenates their outputs along
    the channel dimension and fuses them with a 1x1 projection. The branches
    are: one 1x1 conv/BN/ReLU, one ``ASPPConv`` per atrous rate, and one
    ``ASPPPooling`` image-level branch.

    Args:
        in_channels (int): channels of the input feature map.
        atrous_rates (iterable of int): dilation rate for each 3x3 branch.
            Any number of rates is supported.
        out_channels (int): channels produced by every branch and by the final
            projection. Default: 256.
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()),
        ]

        rates = tuple(atrous_rates)
        modules.extend(ASPPConv(in_channels, out_channels, rate) for rate in rates)

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # The concatenated tensor holds one out_channels-wide slice per branch.
        # Derive the width from the branch count instead of hard-coding 5
        # (which is only right for exactly three atrous rates). Otherwise any
        # other number of rates fails with a channel mismatch in this conv.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        # torch.cat along dim=1 needs matching spatial sizes. The branches are
        # built to keep x's spatial size (padding=dilation, resize in pooling).
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)