]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
re-implemented QuantTestModule using QuantTrainModule. constrain_bias added.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / module_utils.py
index c12a26fdf73109f70d14f55452aa8f797a43ca16..9fcd6e9198eae2c3079e6de22c7f38abc3f76d20 100644 (file)
@@ -10,7 +10,7 @@ def is_normalization(module):
 
 def is_activation(module):
     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
-                                 layers.NoAct, layers.PAct2))
+                                 layers.PAct2, layers.QAct, layers.NoQAct))
     return is_act
 
 def is_pact2(module):
@@ -26,6 +26,9 @@ def is_deconv(module):
 def is_conv_deconv(module):
     return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
 
+def is_conv_deconv_linear(module):
+    return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))
+
 def is_linear(module):
     return isinstance(module, torch.nn.Linear)
 
@@ -106,6 +109,10 @@ def add_module_names(model):
 
 
 def squeeze_list(inputs):
+    return inputs[0] if (is_list(inputs) and len(inputs)==1) else inputs
+
+
+def squeeze_list2(inputs):
     return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs