Cleaned up STE for QAT. Added RegNetX models
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules/pytorch_jacinto_ai/xnn/layers/functional.py
index 996b4c92457ab0b634c4b53ef512f93480d9e889..16f1f3da0e1eea1e4c60b8f9c388bef4536fdbe2 100644 (file)
@@ -1,17 +1,25 @@
+import functools
 import torch
 from . import function
+from . import quant_ste
 
 
 ###################################################
-round_g = function.RoundG.apply
-round_sym_g = function.RoundSymG.apply
-round_up_g = function.RoundUpG.apply
-round2_g = function.Round2G.apply
-
-ceil_g = function.CeilG.apply
-ceil2_g = function.Ceil2G.apply
-
-quantize_dequantize_g = function.QuantizeDequantizeG.apply
+round_g = quant_ste.PropagateQuantTensorSTE(function.RoundG.apply)
+round_sym_g = quant_ste.PropagateQuantTensorSTE(function.RoundSymG.apply)
+round_up_g = quant_ste.PropagateQuantTensorSTE(function.RoundUpG.apply)
+round2_g = quant_ste.PropagateQuantTensorSTE(function.Round2G.apply)
+ceil_g = quant_ste.PropagateQuantTensorSTE(function.CeilG.apply)
+ceil2_g = quant_ste.PropagateQuantTensorSTE(function.Ceil2G.apply)
+
+# Wrapping with PropagateQuantTensorSTE is optional: it causes the backward
+# method of QuantizeDequantizeG to be skipped (straight-through estimation).
+# Replace it with PropagateQuantTensorQTE to let the gradient flow back through the backward method.
+# Note: QTE takes effect here only if it is also used in the forward() of QuantTrainPAct2
+# in quant_train_modules.py, by setting quantize_backward_type = 'qte' in QuantTrainPAct2.
+# TODO: when using QTE here, this op must be registered for ONNX export to work,
+# and even then the exported model may not be clean.
+quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.QuantizeDequantizeG.apply)
 
 
 ###################################################
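
For background, below is a minimal sketch of the straight-through estimator (STE) pattern that the PropagateQuantTensorSTE wrapper applies above. The wrapper itself lives in quant_ste.py, which is not shown in this diff; this standalone example illustrates the idea only and is not the devkit's implementation.

import torch

class RoundSTE(torch.autograd.Function):
    """Round in forward(); pass the gradient straight through in backward()."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: ignore round's true derivative
        # (zero almost everywhere) and pass the incoming gradient unchanged.
        return grad_output

x = torch.tensor([0.4, 1.6], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) - as if round were the identity in backward

Without the straight-through trick, the gradient of round() would be zero almost everywhere and quantization-aware training could not update the weights.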