# pytorch_jacinto_ai/xnn/layers/activation.py
1 import numpy as np
2 import torch
4 from .functional import *
5 from .. import utils
###############################################################
# Parametric Activation (PACT) with clip values being power of 2
# Supports learned mode, estimated mode or fixed range
class PAct2(torch.nn.Module):
    """Parametric clipping activation (PACT-style) with optional power-of-2 clip values.

    The clipping range is obtained in one of three ways:
      * fixed     - clip_range given at construction (fixed_range mode),
      * estimated - exponential moving average of the observed input extrema
                    during training (default, PACT2_RANGE_LEARN == False),
      * learned   - backpropagated clip parameters (PACT2_RANGE_LEARN == True).

    When power2_activation_range is set, the reported clip value is rounded
    up to a power of 2 - convenient for fixed-point / quantized inference.
    """
    # constants
    PACT2_RANGE_LEARN = False   # False : Running Avg, True : Backprop
    PACT2_RANGE_SHRINK = 0.01   # percentile fraction used to shrink the observed range
    PACT2_RANGE_INIT = 8.0      # this is the starting range
    PACT2_RANGE_EXPANSION = 1.0 # expand the calculated range for margin

    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None,
                 power2_activation_range=True, **kwargs):
        """
        Args:
            inplace: request in-place clamping (forwarded to clamp_g).
            signed: True for a symmetric signed range, False for [0, max],
                None to infer signedness from the data / clip values.
            percentile_range_shrink: range-shrink percentile forwarded to
                utils.extrema_fast when estimating the range.
            clip_range: optional (min, max) pair; when given, the range is fixed.
            power2_activation_range: round the clip magnitude up to a power of 2.
            **kwargs: ignored (accepted for interface compatibility).
        """
        super().__init__()
        # when both are given, the signed flag must agree with the sign of clip_range
        if (clip_range is not None) and (signed is not None):
            assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
        #
        self.inplace = inplace
        self.clip_range = clip_range
        self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
        self.percentile_range_shrink = percentile_range_shrink # range shrinking factor
        self.fixed_range = (clip_range is not None)
        self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
        self.eps = np.power(2.0, -16.0)  # lower bound for the clip magnitude (avoids a zero range)
        self.power2_activation_range = power2_activation_range # power of 2 ranges
        self.log_base = None # 2.0 # log is used only in learned mode if log_base is not None

        # any validation before at-least one iteration of training will use default clip values.
        clip_init = max(abs(np.array(clip_range))) if (clip_range is not None) else self.PACT2_RANGE_INIT
        clip_init2 = np.power(2.0, np.ceil(np.log2(clip_init)))

        if self.learn_range:
            # learned mode: clips_act is a trainable Parameter
            # (kept in log domain if log_base is set - see convert_to_log)
            clip_signed_log = self.convert_to_log(torch.tensor(clip_init2))
            default_clips = (-clip_signed_log, clip_signed_log) \
                if (self.signed == True or self.signed is None) else (0.0, clip_signed_log)
            self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips)))
            # Initially ranges will be dominated by running average, but eventually the update factor becomes too small.
            # Then the backprop updates will have dominance.
            self.range_update_factor_min = 0.0
            self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
        else:
            # estimated / fixed mode: clips_act is a non-trainable buffer
            default_clips = (-clip_init2, clip_init2) \
                if (self.signed == True or self.signed is None) else (0.0, clip_init2)
            self.register_buffer('clips_act', torch.tensor(default_clips))
            # range_update_factor_min is the lower bound for exponential update factor.
            # using 0.0 will freeze the ranges, since the update_factor becomes too small after some time
            self.range_update_factor_min = 0.001
            self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
        #

    def forward(self, x, update_activation_range=True, enable=True):
        """Clip x to the current activation range.

        Args:
            x: input tensor.
            update_activation_range: in training mode, also update the running
                range statistics from this batch.
            enable: when False, skip clipping - the input is returned as-is
                (signed) or passed through relu (unsigned).
        """
        if (self.training and update_activation_range):
            self.num_batches_tracked += 1
            # even in learn_range mode - do this for a few iterations to get a good starting point
            if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
                with torch.no_grad():
                    self.update_scale_act(x.data)
            #
        #
        if not enable:
            # bypass: still enforce non-negativity for unsigned activations
            signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
            y = x if signed else torch.nn.functional.relu(x)
        else:
            clips = self.get_clips_act()
            y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
        #
        return y

    def __repr__(self):
        clips = self.get_clips_act()
        return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)

    def convert_to_log(self, x):
        """Map a clip value into log domain (no-op unless learning with log_base set)."""
        if (not self.learn_range) or (self.log_base is None):
            return x
        #
        return utils.signed_log(x, self.log_base)

    def convert_to_linear(self, x):
        """Inverse of convert_to_log (no-op unless learning with log_base set)."""
        if (not self.learn_range) or (self.log_base is None):
            return x
        #
        return utils.signed_pow(x, self.log_base)

    def update_scale_act(self, x):
        """Update clips_act with an exponential moving average of the extrema of x."""
        # compute the new scale
        x_min, x_max = utils.extrema_fast(x, percentile_range_shrink=self.percentile_range_shrink)
        x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
        # exponential update factor: 1/n averaging at first, bounded below by
        # range_update_factor_min so later batches can still move the range
        update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
        update_factor = max(update_factor, self.range_update_factor_min)
        # exponential moving average update
        self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
        self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)

    def get_clips_act(self):
        """Return the effective (clip_min, clip_max) pair in linear domain."""
        # find the clip values
        signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
        clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
        clip_max = torch.clamp(clip_max, min=self.eps)
        clip_max = self.convert_to_linear(clip_max)
        # in range learning mode + training - this power2_activation_range is taken care in the quantize function
        is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
        use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
        clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
        clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
        return (clip_min2, clip_max2)
###############################################################
# return a function that creates PAct2 with the given fixed range
# remember: this function returns a type and not an instance
def get_fixed_pact2_type(inplace=False, signed=None, output_range=None):
    """Build a PAct2 factory whose clip range is fixed to output_range."""
    def FixedPAct2Type(inplace=inplace, signed=signed):
        assert output_range is not None, 'output_range must be specified for FixedPact2'
        # a negative lower bound forces a signed activation
        is_signed = True if ((output_range[0] < 0.0) or (signed is True)) else signed
        return PAct2(inplace=inplace, signed=is_signed, clip_range=output_range)
    #
    return FixedPAct2Type
###############################################################
# return a derivative of Hardtanh with the given fixed range
# remember: this function returns a type and not an instance
def get_fixed_hardtanh_type(*args, **kwargs):
    """Create a Hardtanh subclass frozen to the range captured here."""
    class FixedHardtanhType(torch.nn.Hardtanh):
        # constructor arguments from the caller are deliberately ignored;
        # the range captured by the enclosing call always wins
        def __init__(self, *ignored_args, **ignored_kwargs):
            super().__init__(*args, **kwargs)
    #
    return FixedHardtanhType
###############################################################
class ReLU1(torch.nn.Hardtanh):
    """Clamp the input to [0, 1] - a Hardtanh with a shifted default range."""
    def __init__(self, min_val=0., max_val=1., inplace=False):
        super().__init__(min_val, max_val, inplace)
###############################################################
# Always quantized activation function.
# Inserting this activation function is a simple way to ensure quantization happens at certain places.
class QAct(torch.nn.Module):
    """Identity marker module: flags this point as always-quantized.

    Does no computation itself; the quantization machinery elsewhere looks
    for this module type. The inplace/signed flags are stored for that
    machinery to read.
    """
    def __init__(self, inplace=False, signed=True, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        # pure pass-through
        return x
# Never quantized activation function.
# Also if the next block is this, the previous block output is also not quantized.
# Inserting this activation function is a simple way to avoid quantization at certain places.
class NoQAct(torch.nn.Module):
    """Identity marker module: flags this point as never-quantized.

    Does no computation itself; the quantization machinery elsewhere looks
    for this module type. The inplace/signed flags are stored for that
    machinery to read.
    """
    def __init__(self, inplace=False, signed=True, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        # pure pass-through
        return x