[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / module_utils.py
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
index c12a26fdf73109f70d14f55452aa8f797a43ca16..9fcd6e9198eae2c3079e6de22c7f38abc3f76d20 100644 (file)
def is_activation(module):
is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
- layers.NoAct, layers.PAct2))
+ layers.PAct2, layers.QAct, layers.NoQAct))
return is_act
def is_pact2(module):
def is_conv_deconv(module):
return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
+def is_conv_deconv_linear(module):
+ return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))
+
def is_linear(module):
return isinstance(module, torch.nn.Linear)
def squeeze_list(inputs):
+ return inputs[0] if (is_list(inputs) and len(inputs)==1) else inputs
+
+
+def squeeze_list2(inputs):
return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs