[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / common_blocks.py
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
index 8f9c6577943fb879091b8251a5648deb95147c95..0e9b4b40c5a6905c7d412cfaa90ed1ada1bcabab 100644 (file)
def __repr__(self):
return 'AddBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
+# sub
+class SubtractBlock(torch.nn.Module):
+ def __init__(self, inplace=False, signed=True, *args, **kwargs):
+ super().__init__()
+ self.inplace = inplace
+ self.signed = signed
+
+ def forward(self, x):
+ assert isinstance(x, (list,tuple)), 'input to sub block must be a list or tuple'
+ y = x[0]
+ for i in range(1,len(x)):
+ y = y - x[i]
+ #
+ return y
+
+ def __repr__(self):
+ return 'SubtractBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
+
###########################################################
# mult