]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/layers/activation.py
re-implemented QuantTestModule using QuantTrainModule. constrain_bias added.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
index 325764e0caaea4cc4fb3dbe87ba9283f709fb41a..a1e6ee68af0dab1a6891e821eea07959b237b5a8 100644 (file)
@@ -15,7 +15,7 @@ class PAct2(torch.nn.Module):
     PACT2_RANGE_INIT = 8.0      # this is the starting range
     PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
 
-    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, **kwargs):
+    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, power2_activation_range=True, **kwargs):
         super().__init__()
         if (clip_range is not None) and (signed is not None):
             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
@@ -27,7 +27,7 @@ class PAct2(torch.nn.Module):
         self.fixed_range = (clip_range is not None)
         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
         self.eps = np.power(2.0, -16.0)
-        self.power2 = True   # power of 2 ranges
+        self.power2_activation_range = power2_activation_range   # power of 2 ranges
         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
 
         # any validation before at-least one iteration of training wll use default clip values.
@@ -54,8 +54,8 @@ class PAct2(torch.nn.Module):
         #
 
 
-    def forward(self, x, update_range=True, enable=True):
-        if (self.training and update_range):
+    def forward(self, x, update_activation_range=True, enable=True):
+        if (self.training and update_activation_range):
             self.num_batches_tracked += 1
             # even in learn_range mode - do this for a few iterations to get a good starting point
             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
@@ -66,10 +66,11 @@ class PAct2(torch.nn.Module):
         #
         if not enable:
             signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
-            return x if signed else torch.nn.functional.relu(x)
+            y = x if signed else torch.nn.functional.relu(x)
+        else:
+            clips = self.get_clips_act()
+            y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
         #
-        clips = self.get_clips_act()
-        y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
         return y
 
 
@@ -110,9 +111,10 @@ class PAct2(torch.nn.Module):
         clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
         clip_max = torch.clamp(clip_max, min=self.eps)
         clip_max = self.convert_to_linear(clip_max)
-        # in range learning mode + training - this power2 is taken care in the quantize function
-        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
-        clip_max2 = ceil2_g(clip_max) if use_power2 else clip_max
+        # in range learning mode + training - this power2_activation_range is taken care in the quantize function
+        is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+        use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
+        clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
         return (clip_min2, clip_max2)
 
@@ -136,7 +138,22 @@ class ReLU1(torch.nn.Hardtanh):
 
 
 ###############################################################
-class NoAct(torch.nn.Module):
+# Always quantized activation function.
+# Inserting this activation function is a simple way to ensure quantization happens at certain places.
+class QAct(torch.nn.Module):
+    def __init__(self, inplace=False, signed=True, **kwargs):
+        super().__init__()
+        self.inplace = inplace
+        self.signed = signed
+
+    def forward(self, x):
+        return x
+
+
+# Never quantized activation function.
+# Also if the next block is this, the previous block output is also not quantized.
+# Inserting this activation function is a simple way to avoid quantization at certain places.
+class NoQAct(torch.nn.Module):
     def __init__(self, inplace=False, signed=True, **kwargs):
         super().__init__()
         self.inplace = inplace