]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/models/segmentation/deeplabv3.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / segmentation / deeplabv3.py
index 94215071728923e92539945555ce99bf2dd81337..7acc013ccb1434f1f6e7cbc2eb93bf6ae7d25405 100644 (file)
@@ -14,7 +14,7 @@ class DeepLabV3(_SimpleSegmentationModel):
     `"Rethinking Atrous Convolution for Semantic Image Segmentation"
     <https://arxiv.org/abs/1706.05587>`_.
 
-    Arguments:
+    Args:
         backbone (nn.Module): the network used to compute the features for the model.
             The backbone should return an OrderedDict[Tensor], with the key being
             "out" for the last feature map used, and "aux" if an auxiliary classifier
@@ -57,30 +57,30 @@ class ASPPPooling(nn.Sequential):
 
     def forward(self, x):
         size = x.shape[-2:]
-        x = super(ASPPPooling, self).forward(x)
+        for mod in self:
+            x = mod(x)
         return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
 
 
 class ASPP(nn.Module):
-    def __init__(self, in_channels, atrous_rates):
+    def __init__(self, in_channels, atrous_rates, out_channels=256):
         super(ASPP, self).__init__()
-        out_channels = 256
         modules = []
         modules.append(nn.Sequential(
             nn.Conv2d(in_channels, out_channels, 1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU()))
 
-        rate1, rate2, rate3 = tuple(atrous_rates)
-        modules.append(ASPPConv(in_channels, out_channels, rate1))
-        modules.append(ASPPConv(in_channels, out_channels, rate2))
-        modules.append(ASPPConv(in_channels, out_channels, rate3))
+        rates = tuple(atrous_rates)
+        for rate in rates:
+            modules.append(ASPPConv(in_channels, out_channels, rate))
+
         modules.append(ASPPPooling(in_channels, out_channels))
 
         self.convs = nn.ModuleList(modules)
 
         self.project = nn.Sequential(
-            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
+            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(),
             nn.Dropout(0.5))