[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / resize_blocks.py
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py
index 003c179dfcb156f66b05e65dd5eadee64ce83f75..8bab95da30ba8f0bcc0da3d1853e6fe7b0de9edb 100644 (file)
# older ResizeTo, UpsampleTo. The older modules may be removed in a later version.
##############################################################################################
+# resize with output size or scale factor
def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    # Resize tensor x using either an absolute output size or a scale factor;
    # exactly one of `size` / `scale_factor` must be provided (enforced below).
    # NOTE(review): this is a diff view — the lines that actually compute `y`
    # are elided between the asserts and the return; confirm against the full file.
    assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
    assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
    return y
# always use scale factor to do the rescaling. if scale factor is not provided, generate it from the size.
def resize_with_scale_factor(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """Resize `x` via resize_with(), always driving the resize by scale_factor.

    If `scale_factor` is not given, it is derived from the requested `size`
    and the last two dimensions of `x` (presumably H, W). Exactly one of
    `size` / `scale_factor` must be provided.

    :param x: input tensor to be resized (must be a torch.Tensor).
    :param size: target spatial size — scalar, 2-sequence, or tensor of sizes.
    :param scale_factor: multiplier(s) for the spatial dimensions.
    :param mode: interpolation mode forwarded to resize_with().
    :param align_corners: forwarded to resize_with().
    :return: the resized tensor produced by resize_with().
    """
    assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
    assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
    assert isinstance(x, torch.Tensor), 'must provide a single tensor as input'
    if scale_factor is None:
        # normalize `size` into a 2-element (H, W)-style list of numbers
        if isinstance(size, torch.Tensor):
            size = [float(s) for s in size]
        elif isinstance(size, (int, float)):
            size = [size, size]
        #
        if isinstance(size, (list, tuple)) and len(size) > 2:
            size = size[-2:]
        #
        # per-dimension ratio of requested size to the input's spatial size
        x_size = [float(s) for s in x.size()][-2:]
        scale_factor = [float(s) / float(x_s) for (s, x_s) in zip(size, x_size)]
    #
    # bug fix: the result was previously assigned to `y` but never returned,
    # so the function silently returned None
    return resize_with(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+
class ResizeWith(torch.nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
super().__init__()