[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/activation.py b/modules/pytorch_jacinto_ai/xnn/layers/activation.py
index 325764e0caaea4cc4fb3dbe87ba9283f709fb41a..a1e6ee68af0dab1a6891e821eea07959b237b5a8 100644 (file)
PACT2_RANGE_INIT = 8.0 # this is the starting range
PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
- def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, **kwargs):
+ def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, power2_activation_range=True, **kwargs):
super().__init__()
if (clip_range is not None) and (signed is not None):
assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
self.fixed_range = (clip_range is not None)
self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
self.eps = np.power(2.0, -16.0)
- self.power2 = True # power of 2 ranges
+ self.power2_activation_range = power2_activation_range # power of 2 ranges
self.log_base = None # 2.0 # log is used only in learned mode if log_base is not None
        # any validation before at least one iteration of training will use default clip values.
#
- def forward(self, x, update_range=True, enable=True):
- if (self.training and update_range):
+ def forward(self, x, update_activation_range=True, enable=True):
+ if (self.training and update_activation_range):
self.num_batches_tracked += 1
# even in learn_range mode - do this for a few iterations to get a good starting point
if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
#
if not enable:
signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
- return x if signed else torch.nn.functional.relu(x)
+ y = x if signed else torch.nn.functional.relu(x)
+ else:
+ clips = self.get_clips_act()
+ y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
#
- clips = self.get_clips_act()
- y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
return y
clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
clip_max = torch.clamp(clip_max, min=self.eps)
clip_max = self.convert_to_linear(clip_max)
- # in range learning mode + training - this power2 is taken care in the quantize function
- use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
- clip_max2 = ceil2_g(clip_max) if use_power2 else clip_max
+        # in range learning mode + training - this power2_activation_range is taken care of in the quantize function
+ is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+ use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
+ clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
return (clip_min2, clip_max2)
###############################################################
-class NoAct(torch.nn.Module):
+# Always quantized activation function.
+# Inserting this activation function is a simple way to ensure quantization happens at certain places.
+class QAct(torch.nn.Module):
+    """Identity (pass-through) activation used as a quantization marker.
+
+    Inserting this module is a simple way to ensure that quantization
+    happens at this point in the network; the module itself does not
+    transform its input.
+
+    Args:
+        inplace: stored flag only; no in-place operation is performed here.
+        signed: stored flag indicating whether the activation range is signed.
+        **kwargs: accepted and ignored, for signature compatibility with
+            the other activation modules in this file.
+    """
+    def __init__(self, inplace=False, signed=True, **kwargs):
+        super().__init__()
+        self.inplace = inplace
+        self.signed = signed
+
+    def forward(self, x):
+        # Pure identity: quantization (if any) is applied by the wrapper,
+        # not by this module.
+        return x
+
+
+# Never quantized activation function.
+# Also, if the next block is a NoQAct, the output of the previous block is not quantized either.
+# Inserting this activation function is a simple way to avoid quantization at certain places.
+class NoQAct(torch.nn.Module):
def __init__(self, inplace=False, signed=True, **kwargs):
super().__init__()
self.inplace = inplace