]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py
docs update and minor fixes
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / resize_blocks.py
index 003c179dfcb156f66b05e65dd5eadee64ce83f75..8bab95da30ba8f0bcc0da3d1853e6fe7b0de9edb 100644 (file)
@@ -9,6 +9,7 @@ from .deconv_blocks import *
 # older ResizeTo, UpsampleTo. The older modules may be removed in a later version.
 ##############################################################################################
 
+# resize with output size or scale factor
 def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
     assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
     assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
@@ -32,6 +33,26 @@ def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=N
     return y
 
 
+# always use scale factor to do the rescaling. if scale factor is not provided, generate it from the size.
+def resize_with_scale_factor(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
+    assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
+    assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
+    assert isinstance(x, torch.Tensor), 'must provide a single tensor as input'
+    if scale_factor is None:
+        if isinstance(size, torch.Tensor):
+            size = [float(s) for s in size]
+        elif isinstance(size, (int,float)):
+            size = [size,size]
+        #
+        if isinstance(size, (list,tuple)) and len(size) > 2:
+            size = size[-2:]
+        #
+        x_size = [float(s) for s in x.size()][-2:]
+        scale_factor = [float(s)/float(x_s) for (s,x_s) in zip(size,x_size)]
+    #
+    y = resize_with(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+
+
 class ResizeWith(torch.nn.Module):
     def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
         super().__init__()