[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/activation.py b/modules/pytorch_jacinto_ai/xnn/layers/activation.py
index d89e6eea496517cbc2c320905e891abf2ba20b15..5e6a7c6c4f4419d15846e4bfeb725cb038c3fa14 100644 (file)
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
import numpy as np
import torch
from .. import utils
-###############################################################
# Parametric Activation (PACT) with clip values being power of 2
# Supports learned mode, estimated mode or fixed range
class PAct2(torch.nn.Module):
- # constants
- PACT2_RANGE_LEARN = False # False : Running Avg, True : Backprop
- PACT2_RANGE_SHRINK = 0.01 # 0.01
- PACT2_RANGE_INIT = 8.0 # this is the starting range
- PACT2_RANGE_EXPANSION = 1.0 # expand the calculated range for margin
+ # constants - default/init values
+ PACT2_RANGE_LEARN_MODE = False # False : Running Avg, True : Backprop
+ PACT2_RANGE_SHRINK_DEFAULT = 0.01 # 0.01
+ PACT2_RANGE_INIT = 8.0 # this is the starting range
+ PACT2_RANGE_EXPANSION_FACTOR = 1.0 # expand the calculated range for margin
- def __init__(self, inplace=False, signed=None, range_shrink_percentile=PACT2_RANGE_SHRINK, clip_range=None,
+ def __init__(self, inplace=False, signed=None, range_shrink_activations=PACT2_RANGE_SHRINK_DEFAULT, clip_range=None,
power2_activation_range=True, **kwargs):
super().__init__()
if (clip_range is not None) and (signed is not None):
self.inplace = inplace
self.clip_range = clip_range
self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
- self.range_shrink_percentile = range_shrink_percentile # range shrinking factor
+ self.range_shrink_activations = range_shrink_activations # range shrinking factor
self.fixed_range = (clip_range is not None)
- self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
+ self.learn_range = (self.PACT2_RANGE_LEARN_MODE and (not self.fixed_range))
self.eps = np.power(2.0, -16.0)
self.power2_activation_range = power2_activation_range # power of 2 ranges
self.log_base = None # 2.0 # log is used only in learned mode if log_base is not None
self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
if utils.has_range_estimator:
- self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_percentile,
+ self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_activations,
range_update_factor_min=self.range_update_factor_min)
#
#
return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)
    def freeze_range(self):
        # Mark the clip range as fixed so it is no longer re-estimated from
        # data — the same state produced by constructing with an explicit
        # clip_range (see __init__: fixed_range = (clip_range is not None)).
        # NOTE(review): presumably called at the end of range calibration /
        # before export — confirm against the callers of freeze_quant_range.
        self.fixed_range = True
+
def convert_to_log(self, x):
if (not self.learn_range) or (self.log_base is None):
return x
def update_clips_act(self, x):
if self.learn_range or (self.range_estimator is None):
- x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_percentile)
- x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
+ x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_activations)
+ x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION_FACTOR), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION_FACTOR)
# exponential update factor
update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
update_factor = max(update_factor, self.range_update_factor_min)
clip_max = torch.clamp(clip_max, min=self.eps)
clip_max = self.convert_to_linear(clip_max)
# in range learning mode + training - this power2_activation_range is taken care in the quantize function
- is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+ is_learning_range = (self.PACT2_RANGE_LEARN_MODE and self.training)
use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
return x
+###############################################################
# Never quantized activation function.
# Also if the next block is this, the previous block output is also not quantized.
# Inserting this activation function is a simple way to avoid quantization at certain places.
self.signed = signed
    def forward(self, x):
        # Deliberate identity pass-through: this activation does nothing, so
        # inserting it is a marker that this point (and the preceding block's
        # output) must not be quantized — see the comment above the class.
        return x
+
+
+###############################################################
def freeze_quant_range(module):
    """Freeze all quantization ranges inside *module*.

    Walks the module tree and calls freeze_range() on every PAct2
    activation, then disables all torch quantization observers so that
    no further range statistics are collected.

    Args:
        module: a torch.nn.Module whose submodules are traversed via apply().
    """
    def _freeze(submodule):
        if isinstance(submodule, PAct2):
            submodule.freeze_range()

    module.apply(_freeze)
    module.apply(torch.quantization.disable_observer)