[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / ops / roi_pool.py
1 import torch
2 from torch import nn, Tensor
4 from torch.nn.modules.utils import _pair
5 from torch.jit.annotations import List, BroadcastingList2
7 from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
def roi_pool(
    input: Tensor,
    boxes: Tensor,
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:
    """
    Apply the Region of Interest (RoI) Pool operator from Fast R-CNN.

    Arguments:
        input (Tensor[N, C, H, W]): feature tensor the regions are pooled from
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): regions given as
            (x1, y1, x2, y2) coordinates. A single Tensor must carry the batch
            index in its first column. A list supplies one Tensor per batch
            element i, holding the boxes for that element; it is converted to
            the single-Tensor layout before pooling.
        output_size (int or Tuple[int, int]): size of each pooled region as
            (height, width); a single int is expanded to a pair
        spatial_scale (float): scaling factor that maps the box coordinates
            onto the coordinates of ``input``. Default: 1.0

    Returns:
        output (Tensor[K, C, output_size[0], output_size[1]])
    """
    check_roi_boxes_shape(boxes)
    pooled_size = _pair(output_size)

    # Normalise the boxes to the single-Tensor RoI layout the operator takes.
    if isinstance(boxes, torch.Tensor):
        rois = boxes
    else:
        rois = convert_boxes_to_roi_format(boxes)

    # The operator yields a pair; only its first element is handed back.
    pooled, _unused = torch.ops.torchvision.roi_pool(
        input, rois, spatial_scale, pooled_size[0], pooled_size[1]
    )
    return pooled
class RoIPool(nn.Module):
    """
    Module wrapper around the ``roi_pool`` function.

    The pooling configuration is stored at construction time and passed to
    ``roi_pool`` on every forward call. See roi_pool for the meaning of the
    arguments.
    """

    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        """Store ``output_size`` and ``spatial_scale`` unchanged for later calls."""
        super().__init__()
        self.spatial_scale = spatial_scale
        self.output_size = output_size

    def forward(self, input: Tensor, rois: Tensor) -> Tensor:
        """Pool ``rois`` out of ``input`` using the stored configuration."""
        return roi_pool(
            input,
            rois,
            self.output_size,
            self.spatial_scale,
        )

    def __repr__(self) -> str:
        """Return ``ClassName(output_size=..., spatial_scale=...)``."""
        pieces = [
            self.__class__.__name__,
            '(',
            'output_size=',
            str(self.output_size),
            ', spatial_scale=',
            str(self.spatial_scale),
            ')',
        ]
        return ''.join(pieces)