d89e6eea496517cbc2c320905e891abf2ba20b15
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
1 import numpy as np
2 import torch
4 from .functional import *
5 from .. import utils
8 ###############################################################
9 # Parametric Activation (PACT) with clip values being power of 2
10 # Supports learned mode, estimated mode or fixed range
class PAct2(torch.nn.Module):
    """Parametric activation (PACT) that clips its input to a tracked, learned or fixed range.

    The clip range is stored in ``clips_act`` as a ``(min, max)`` pair. It is a buffer
    normally, or a ``torch.nn.Parameter`` when the range is learned by backprop. When the
    range is applied (see ``get_clips_act``), it is made symmetric for signed activations
    and, by default, rounded up to a power of 2.

    Range modes:
      * fixed     - ``clip_range`` is given; ``clips_act`` is never updated.
      * estimated - ``PACT2_RANGE_LEARN`` is False (default). ``clips_act`` follows the batch
                    extrema, via ``utils.RangeEstimator`` when it is available, otherwise via
                    an exponential moving average.
      * learned   - ``PACT2_RANGE_LEARN`` is True. ``clips_act`` is a Parameter trained by
                    backprop, seeded by the moving average during the first
                    ``PACT2_RANGE_LEARN_WARMUP_BATCHES`` training batches.

    Args:
        inplace: forwarded to ``clamp_g`` as its in-place flag.
        signed: whether the activation is signed. ``None`` means infer it from the sign of
            ``clips_act[0]``. Must agree with ``clip_range`` when both are given.
        range_shrink_percentile: percentile handed to the range estimators (range shrinking factor).
        clip_range: optional fixed ``(min, max)`` range; disables all range updates.
        power2_activation_range: round the applied clip range up to a power of 2.
        **kwargs: accepted and ignored.
    """
    # constants
    PACT2_RANGE_LEARN = False                 # False : Running Avg, True : Backprop
    PACT2_RANGE_SHRINK = 0.01                 # 0.01
    PACT2_RANGE_INIT = 8.0                    # this is the starting range
    PACT2_RANGE_EXPANSION = 1.0               # expand the calculated range for margin
    PACT2_RANGE_LEARN_WARMUP_BATCHES = 100    # learned mode: batches seeded by the running average

    def __init__(self, inplace=False, signed=None, range_shrink_percentile=PACT2_RANGE_SHRINK, clip_range=None,
                 power2_activation_range=True, **kwargs):
        super().__init__()
        if (clip_range is not None) and (signed is not None):
            assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
        #
        self.inplace = inplace
        self.clip_range = clip_range
        # a fixed clip_range always determines the sign from its lower bound
        self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
        self.range_shrink_percentile = range_shrink_percentile # range shrinking factor
        self.fixed_range = (clip_range is not None)
        self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
        self.eps = np.power(2.0, -16.0)
        self.power2_activation_range = power2_activation_range # power of 2 ranges
        self.log_base = None # 2.0 # log is used only in learned mode if log_base is not None
        self.range_estimator = None

        # any validation before at-least one iteration of training will use default clip values.
        clip_init = max(abs(np.array(clip_range))) if (clip_range is not None) else self.PACT2_RANGE_INIT
        clip_init2 = np.power(2.0, np.ceil(np.log2(clip_init)))
        # '==' rather than 'is': signed may be a numpy bool when clip_range is a numpy array.
        # An unknown sign (None) starts from a symmetric range.
        symmetric = (self.signed == True or self.signed is None)

        if self.learn_range:
            clip_signed_log = self.convert_to_log(torch.tensor(clip_init2))
            default_clips = (-clip_signed_log, clip_signed_log) if symmetric else (0.0, clip_signed_log)
            self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips, dtype=torch.float32)))
            # Initially ranges will be dominated by running average, but eventually the update factor becomes too small.
            # Then the backprop updates will have dominance.
            self.range_update_factor_min = 0.0
        else:
            default_clips = (-clip_init2, clip_init2) if symmetric else (0.0, clip_init2)
            self.register_buffer('clips_act', torch.tensor(default_clips, dtype=torch.float32))
            # range_update_factor_min is the lower bound for exponential update factor.
            # using 0.0 will freeze the ranges, since the update_factor becomes too small after some time
            self.range_update_factor_min = 0.001
        #
        # -1 so that the first training forward (which increments first) sees 0
        self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))

        if utils.has_range_estimator:
            self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_percentile,
                                                        range_update_factor_min=self.range_update_factor_min)
        #
    #

    def forward(self, x, update_activation_range=True, enable=True):
        """Clip ``x`` to the current activation range.

        Args:
            x: input tensor.
            update_activation_range: in training mode, allow this call to update ``clips_act``.
            enable: if False, skip clipping and return ``x`` for signed activations or
                ``relu(x)`` otherwise.

        Returns:
            The activated tensor.
        """
        if (self.training and update_activation_range):
            self.num_batches_tracked += 1
            # even in learn_range mode - do this for a few iterations to get a good starting point
            if (not self.fixed_range) and ((not self.learn_range) or
                                           (self.num_batches_tracked < self.PACT2_RANGE_LEARN_WARMUP_BATCHES)):
                with torch.no_grad():
                    self.update_clips_act(x.data)
                #
            #
        #
        if not enable:
            signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
            y = x if signed else torch.nn.functional.relu(x)
        else:
            clips = self.get_clips_act()
            y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
        #
        return y

    def __repr__(self):
        clips = self.get_clips_act()
        return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)

    def convert_to_log(self, x):
        """Map ``x`` through ``utils.signed_log`` when learning the range with a log base; identity otherwise.

        ``log_base`` is initialised to None, so this is the identity unless it is set elsewhere.
        """
        if (not self.learn_range) or (self.log_base is None):
            return x
        #
        return utils.signed_log(x, self.log_base)

    def convert_to_linear(self, x):
        """Inverse of ``convert_to_log`` (``utils.signed_pow``); identity under the same conditions."""
        if (not self.learn_range) or (self.log_base is None):
            return x
        #
        return utils.signed_pow(x, self.log_base)

    def update_clips_act(self, x):
        """Update ``clips_act`` in place from the statistics of batch ``x``.

        ``forward`` calls this under ``torch.no_grad()``. When learning the range, or when no
        ``RangeEstimator`` is available, an exponential moving average of the shrunk batch
        extrema is used. Its update factor is ``1/num_batches_tracked`` (1 on the first batch),
        floored at ``range_update_factor_min``. Otherwise ``clips_act`` is overwritten with the
        estimator's output.
        """
        if self.learn_range or (self.range_estimator is None):
            x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_percentile)
            x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
            # exponential update factor
            update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
            update_factor = max(update_factor, self.range_update_factor_min)
            # exponential moving average update
            self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
            self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)
        else:
            mn, mx = self.range_estimator(x)
            self.clips_act[0].data.fill_(mn)
            self.clips_act[1].data.fill_(mx)
        #

    def get_clips_act(self):
        """Return the ``(min, max)`` clip values to apply.

        Signed activations get a symmetric range ``(-m, m)``, where ``m`` is the largest
        magnitude in ``clips_act``. Unsigned activations get ``(0, |clips_act[1]|)``. The
        magnitude is floored at ``eps``. It is rounded up to a power of 2 when
        ``power2_activation_range`` is set, except while actually learning the range in
        training mode.
        """
        # find the clip values
        signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
        clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
        clip_max = torch.clamp(clip_max, min=self.eps)
        clip_max = self.convert_to_linear(clip_max)
        # in range learning mode + training - this power2_activation_range is taken care in the quantize function.
        # Use the instance flag: a fixed-range instance never learns its range, even if PACT2_RANGE_LEARN is set.
        is_learning_range = (self.learn_range and self.training)
        use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
        clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
        clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
        return (clip_min2, clip_max2)
134 ###############################################################
135 # return a function that creates PAct2 with the given fixed range
# remember: this function returns a factory function (not a class, and not an instance)
def get_fixed_pact2_type(inplace=False, signed=None, output_range=None):
    """Return a factory that builds a PAct2 clipping to the fixed range ``output_range``.

    The returned callable is a plain function, not a class. It accepts ``inplace`` and
    ``signed`` keyword overrides whose defaults are the values given here. ``output_range``
    is validated only when the factory is called. A negative lower bound in
    ``output_range`` forces ``signed=True``. Requesting ``signed=True`` with a non-negative
    lower bound is rejected by PAct2's consistency assertion.
    """
    def FixedPAct2Type(inplace=inplace, signed=signed):
        assert output_range is not None, 'output_range must be specified for FixedPact2'
        # a negative lower bound can only be represented by a signed activation
        if output_range[0] < 0.0:
            signed = True
        return PAct2(inplace=inplace, signed=signed, clip_range=output_range)

    return FixedPAct2Type
147 ###############################################################
148 # return a derivative of Hardtanh with the given fixed range
149 # remember: this function returns a type and not an instance
def get_fixed_hardtanh_type(*args, **kwargs):
    """Return a ``torch.nn.Hardtanh`` subclass whose constructor arguments are fixed.

    ``args``/``kwargs`` are captured here and always passed to ``Hardtanh.__init__``.
    Any arguments given when instantiating the returned class are ignored.
    """
    fixed_args, fixed_kwargs = args, kwargs

    class FixedHardtanhType(torch.nn.Hardtanh):
        def __init__(self, *_ignored_args, **_ignored_kwargs):
            # the range captured by the factory always wins over call-site arguments
            super().__init__(*fixed_args, **fixed_kwargs)

    return FixedHardtanhType
158 ###############################################################
class ReLU1(torch.nn.Hardtanh):
    """Hardtanh clamping to ``[min_val, max_val]``; by default ``[0, 1]``, a ReLU that saturates at 1."""

    def __init__(self, min_val=0., max_val=1., inplace=False):
        # Hardtanh's leading positional parameters are (min_val, max_val, inplace)
        super().__init__(min_val, max_val, inplace)
164 ###############################################################
165 # Always quantized activation function.
166 # Inserting this activation function is a simple way to ensure quantization happens at certain places.
class QAct(torch.nn.Module):
    """Pass-through activation that marks a place where the output is always quantized.

    The module performs no computation. ``inplace`` and ``signed`` are only stored as
    attributes (presumably read by the quantization tooling; confirm against callers).
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, inplace=False, signed=True, **kwargs):
        super().__init__()
        self.inplace, self.signed = inplace, signed

    def forward(self, x):
        """Return ``x`` unchanged (the very same tensor object)."""
        return x
177 # Never quantized activation function.
178 # Also if the next block is this, the previous block output is also not quantized.
179 # Inserting this activation function is a simple way to avoid quantization at certain places.
class NoQAct(torch.nn.Module):
    """Pass-through activation that marks a place where quantization must not happen.

    If the next block is this module, the previous block's output is also left
    unquantized. The module performs no computation. ``inplace`` and ``signed`` are only
    stored as attributes (presumably read by the quantization tooling; confirm against
    callers). Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, inplace=False, signed=True, **kwargs):
        super().__init__()
        self.inplace, self.signed = inplace, signed

    def forward(self, x):
        """Return ``x`` unchanged (the very same tensor object)."""
        return x