[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / ops / feature_pyramid_network.py
diff --git a/modules/pytorch_jacinto_ai/xvision/ops/feature_pyramid_network.py b/modules/pytorch_jacinto_ai/xvision/ops/feature_pyramid_network.py
index 6c6c65bb4f2bc26a8adcb4038a57123c47486cbd..7d72769ab070f99e4055490c7082defe9f034137 100644 (file)
from collections import OrderedDict
-import torch
import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
+
+from typing import Tuple, List, Dict, Optional
+
+
+class ExtraFPNBlock(nn.Module):
+ """
+ Base class for the extra block in the FPN.
+
+ Args:
+ results (List[Tensor]): the result of the FPN
+ x (List[Tensor]): the original feature maps
+ names (List[str]): the names for each one of the
+ original feature maps
+
+ Returns:
+ results (List[Tensor]): the extended set of results
+ of the FPN
+ names (List[str]): the extended set of names for the results
+ """
+ def forward(
+ self,
+ results: List[Tensor],
+ x: List[Tensor],
+ names: List[str],
+ ) -> Tuple[List[Tensor], List[str]]:
+ # Intentionally a no-op: concrete subclasses (e.g. LastLevelMaxPool
+ # below) override this to append extra feature maps/names and return
+ # the extended (results, names) pair.
+ pass
class FeaturePyramidNetwork(nn.Module):
The input to the model is expected to be an OrderedDict[Tensor], containing
the feature maps on top of which the FPN will be added.
- Arguments:
+ Args:
in_channels_list (list[int]): number of channels for each feature map that
is passed to the module
out_channels (int): number of channels of the FPN representation
Examples::
- >>> m = pytorch_jacinto_ai.xvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
+ >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
>>> # get some dummy data
>>> x = OrderedDict()
>>> x['feat0'] = torch.rand(1, 10, 64, 64)
>>> ('feat3', torch.Size([1, 5, 8, 8]))]
"""
-
- def __init__(self, in_channels_list, out_channels, extra_blocks=None):
+ def __init__(
+ self,
+ in_channels_list: List[int],
+ out_channels: int,
+ extra_blocks: Optional[ExtraFPNBlock] = None,
+ ):
super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList()
for in_channels in in_channels_list:
if in_channels == 0:
+ # Behavior change: zero-channel entries were silently skipped
+ # ('continue') before; now they fail fast. Callers that relied
+ # on the skip must filter their channel lists up front.
+ raise ValueError("in_channels=0 is currently not supported")
inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.inner_blocks.append(inner_block_module)
self.layer_blocks.append(layer_block_module)
# initialize parameters now to avoid modifying the initialization of top_blocks
- for m in self.children():
+ # Fix: children() only yields the two nn.ModuleList containers (never a
+ # raw nn.Conv2d), so the kaiming init below never ran; modules()
+ # recurses into the lists and reaches the convs.
+ for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, a=1)
nn.init.constant_(m.bias, 0)
assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks
- def forward(self, x):
+ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
+ """
+ This is equivalent to self.inner_blocks[idx](x),
+ but torchscript doesn't support this yet
+ """
+ num_blocks = len(self.inner_blocks)
+ if idx < 0:
+ # normalize negative indices (e.g. -1 -> last block)
+ idx += num_blocks
+ i = 0
+ # default to the input unchanged; only the idx-th module is applied
+ out = x
+ for module in self.inner_blocks:
+ if i == idx:
+ out = module(x)
+ i += 1
+ return out
+
+ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
+ """
+ This is equivalent to self.layer_blocks[idx](x),
+ but torchscript doesn't support this yet
+ """
+ num_blocks = len(self.layer_blocks)
+ if idx < 0:
+ # normalize negative indices (e.g. -1 -> last block)
+ idx += num_blocks
+ i = 0
+ # default to the input unchanged; only the idx-th module is applied
+ out = x
+ for module in self.layer_blocks:
+ if i == idx:
+ out = module(x)
+ i += 1
+ return out
+
+ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.
- Arguments:
+ Args:
x (OrderedDict[Tensor]): feature maps for each feature level.
Returns:
names = list(x.keys())
x = list(x.values())
- last_inner = self.inner_blocks[-1](x[-1])
+ last_inner = self.get_result_from_inner_blocks(x[-1], -1)
results = []
- results.append(self.layer_blocks[-1](last_inner))
- for feature, inner_block, layer_block in zip(
- x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
- ):
- if not inner_block:
- continue
- inner_lateral = inner_block(feature)
+ results.append(self.get_result_from_layer_blocks(last_inner, -1))
+
+ # top-down pass over the remaining levels (second-coarsest down to
+ # finest); index-based loop replaces the reversed zip so the idx-based
+ # TorchScript-friendly helpers above can be used
+ for idx in range(len(x) - 2, -1, -1):
+ inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
feat_shape = inner_lateral.shape[-2:]
inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
last_inner = inner_lateral + inner_top_down
- results.insert(0, layer_block(last_inner))
+ results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
if self.extra_blocks is not None:
results, names = self.extra_blocks(results, x, names)
return out
-class ExtraFPNBlock(nn.Module):
- """
- Base class for the extra block in the FPN.
-
- Arguments:
- results (List[Tensor]): the result of the FPN
- x (List[Tensor]): the original feature maps
- names (List[str]): the names for each one of the
- original feature maps
-
- Returns:
- results (List[Tensor]): the extended set of results
- of the FPN
- names (List[str]): the extended set of names for the results
- """
- def forward(self, results, x, names):
- pass
-
-
class LastLevelMaxPool(ExtraFPNBlock):
"""
Applies a max_pool2d on top of the last feature map
"""
- def forward(self, x, y, names):
+ def forward(
+ self,
+ x: List[Tensor],
+ y: List[Tensor],
+ names: List[str],
+ ) -> Tuple[List[Tensor], List[str]]:
+ # mutates x/names in place: appends a kernel-1, stride-2, pad-0
+ # max-pool of the last map (i.e. a 2x spatial subsampling); y is unused
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
- def __init__(self, in_channels, out_channels):
+ def __init__(self, in_channels: int, out_channels: int):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
nn.init.constant_(module.bias, 0)
self.use_P5 = in_channels == out_channels
- def forward(self, p, c, names):
+ def forward(
+ self,
+ p: List[Tensor],
+ c: List[Tensor],
+ names: List[str],
+ ) -> Tuple[List[Tensor], List[str]]:
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)