# modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
import warnings
import numpy as np
import torch
from .. import utils
from .. import layers
from .quant_base_module import *
from .quant_utils import *

###########################################################
class QuantTrainParams:
    pass


def get_qparams():
    qparams = QuantTrainParams()
    qparams.inputs = []
    qparams.modules = []
    return qparams


def is_merged_layer(x):
    is_merged = (hasattr(x, 'qparams') and isinstance(x.qparams, QuantTrainParams) and len(x.qparams.modules)>0)
    return is_merged

###########################################################
class QuantTrainConv2d(torch.nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True
        self.bitwidth_weights = None
        self.bitwidth_activations = None
        self.per_channel_q = False

    def forward(self, x):
        is_merged = is_merged_layer(x)
        if is_merged:
            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
        #

        y = super().forward(x)

        if not self.quantize_enable:
            # if quantization is disabled - return
            return y
        #

        qparams = get_qparams()
        qparams.inputs.append(x)
        qparams.modules.append(self)
        y.qparams = qparams
        #
        return y
    #
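

# A minimal usage sketch (illustrative only; the names below are not used elsewhere in
# this module): QuantTrainConv2d behaves like torch.nn.Conv2d, but it also tags its
# output tensor with qparams so that a downstream QuantTrainPAct2 can later fold and
# fake-quantize the merged weights.
def _example_quant_train_conv2d():
    conv = QuantTrainConv2d(3, 8, kernel_size=3, padding=1)
    conv.bitwidth_weights, conv.bitwidth_activations = 8, 8
    x = torch.rand(1, 3, 32, 32)
    y = conv(x)
    # the output carries the original input and the module that produced it
    assert is_merged_layer(y) and (y.qparams.modules[-1] is conv)
    return y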

###########################################################
class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True
        self.bitwidth_weights = None
        self.bitwidth_activations = None
        self.per_channel_q = False

    def forward(self, x):
        is_merged = is_merged_layer(x)
        if is_merged:
            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
        #

        y = super().forward(x)

        if not self.quantize_enable:
            # if quantization is disabled - return
            return y
        #

        qparams = get_qparams()
        qparams.inputs.append(x)
        qparams.modules.append(self)
        y.qparams = qparams
        #
        return y
    #

###########################################################
class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True

    def forward(self, x):
        y = super().forward(x)

        if not self.quantize_enable:
            # if quantization is disabled - return
            return y
        #

        if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
            qparams = get_qparams()
            qparams.inputs = [x.qparams.inputs[0], x]
            qparams.modules = [x.qparams.modules[0], self]
            y.qparams = qparams
        #

        return y
    #
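

# Sketch of the conv->bn tagging chain (illustrative only, with hypothetical shapes):
# the BatchNorm output keeps the original conv input and records [conv, bn], which is
# exactly what QuantTrainPAct2.merge_quantize_weights expects for folding.
def _example_conv_bn_chain():
    conv = QuantTrainConv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = QuantTrainBatchNorm2d(8)
    x = torch.rand(1, 3, 32, 32)
    y = bn(conv(x))
    assert y.qparams.modules == [conv, bn]
    assert y.qparams.inputs[0] is x
    return y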

###########################################################
# fake quantized PAct2 for training
class QuantTrainPAct2(layers.PAct2):
    def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)

        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        # weight shrinking is done by clamping the weights - set this factor to zero.
        # this must be zero - in PAct we do not have the actual weight param, only a temporary tensor,
        # so any clipping done here would not be stored in the weight params
        self.range_shrink_weights = 0.0
        self.round_dither = 0.0
        self.update_range = True
        self.quantize_enable = True
        self.quantize_weights = True
        self.quantize_bias = True
        self.quantize_activations = True
        self.constrain_weights = True
        self.bias_calibration = False
        # save quantized weight/bias once in a while into the params - not needed
        self.params_save_frequency = None #(10 if self.bias_calibration else None)

        # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
        # For a comparison of STE and ABE, read:
        # Learning low-precision neural networks without Straight-Through Estimator (STE):
        # https://arxiv.org/pdf/1903.01061.pdf
        self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
        self.alpha_blending_estimation_factor = 0.5

        if (layers.PAct2.PACT2_RANGE_LEARN):
            assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
                'straight through estimation should not be used when PACT clip values are being learned, as it does not backpropagate gradients through the quantization'
        #
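
    # Estimator selection sketch (illustrative): alpha blending can be chosen after
    # construction; it mixes the float and quantized outputs instead of overwriting
    # the float values, e.g.
    #   pact = QuantTrainPAct2(bitwidth_weights=8, bitwidth_activations=8)
    #   pact.quantized_estimation_type = QuantEstimationType.ALPHA_BLENDING_ESTIMATION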

    def forward(self, x):
        assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
            'bitwidth_weights and bitwidth_activations must not be None'

        # the pact range update happens here - but range clipping depends on quantize_enable
        y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)

        if not self.quantize_enable:
            return y
        #

        # previous intermediate outputs and other information are available
        # for example - conv-bn-relu may need to be merged together.
        is_merged = is_merged_layer(x)
        if is_merged:
            qparams = x.qparams
            xorg = qparams.inputs[0]
            conv, weight, bias = self.merge_quantize_weights(qparams, is_merged)
        else:
            conv, weight, bias = None, None, None
        #

        if is_merged and utils.is_conv(conv):
            xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
        elif is_merged and utils.is_deconv(conv):
            xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
                                                      dilation=conv.dilation, groups=conv.groups)
        else:
            xq = x
        #

        if (self.quantize_enable and self.quantize_activations):
            clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
            width_min, width_max = self.get_widths_act()
            # no need to call super().forward here as clipping with width_min/width_max-1 after scaling has the same effect.
            # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
            # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
            # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
        else:
            yq = super().forward(xq, update_range=False, enable=True)
        #

        if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
            # replace the float output with the quantized version.
            # the weight merging and quantization ops stay out of the autograd graph, so the
            # backward gradients flow only through the float path - this is straight through estimation (STE)
            with torch.no_grad():
                y.data.copy_(yq.data)
            #
        elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
            # TODO: vary the alpha blending factor over the epochs
            y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
        elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
            # pass on the quantized output - the backward gradients also flow through quantization.
            # however, note that the gradients of the round and ceil operators are forced to be unity (1.0).
            y = yq
        else:
            assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
        #

        return y
    #

    def apply_constrain_weights(self, merged_weight):
        return constrain_weight(merged_weight)

    def merge_quantize_weights(self, qparams, is_merged):
        # store the quantized weights and biases only infrequently - otherwise learning will be poor.
        # since this may not happen at the end of the epoch, there can be a slight mismatch in validation accuracy
        first_training_iter = self.training and (self.num_batches_tracked == 0)
        is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)

        conv, bn = None, None
        # merge weight and bias (if possible) across layers
        if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
            conv = qparams.modules[-2]
            conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
            #
            bn = qparams.modules[-1]
            bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
            bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
            #
            merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
            merged_bias = (conv_bias - bn.running_mean) * merged_scale + bn_bias
            merged_weight = conv.weight * merged_scale.view(-1, 1, 1, 1)
            #
            merged_scale_sign = merged_scale.sign()
            merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # turn the 0s in sign into 1
            merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
            merged_scale_inv = 1.0 / merged_scale_eps
            #
        elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
            conv = qparams.modules[-1]
            merged_weight = conv.weight
            merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
            merged_scale = torch.ones(conv.out_channels).to(conv.weight.device)
            merged_scale_inv = torch.ones(conv.out_channels).to(conv.weight.device)
        elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
            assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - please inspect the model carefully'
            bn = qparams.modules[-1]
            merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
            merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
        else:
            assert False, f'quantization: previous layer is unrecognized {qparams.modules} - please inspect the model carefully'
            merged_weight = 0.0
            merged_bias = 0.0
        #

        # quantize weight and bias
        if (conv is not None):
            if (self.quantize_enable and self.quantize_weights):
                if self.constrain_weights and first_training_iter:
                    with torch.no_grad():
                        # clamp merged weights, invert the bn and copy to conv weight
                        constrained_weight = self.apply_constrain_weights(merged_weight.data)
                        merged_weight.data.copy_(constrained_weight.data)
                        # store the clipped weight after inverting the bn
                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
                    #
                #

                is_dw = utils.is_dwconv(conv)
                use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
                if use_per_channel_q:
                    channels = int(merged_weight.size(0))
                    scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
                    scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
                    for chan_id in range(channels):
                        clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
                        scale2[chan_id,0,0,0] = scale2_value
                        scale_inv2[chan_id,0,0,0] = scale_inv2_value
                    #
                else:
                    clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
                #
                width_min, width_max = self.get_widths_w()
                # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
            #

            if (self.quantize_enable and self.quantize_bias):
                bias_width_min, bias_width_max = self.get_widths_bias()
                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
                # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
            #

            # invert the bn operation and store weights/bias
            if self.training and is_store_weight_bias_iter:
                with torch.no_grad():
                    if self.quantize_enable and self.quantize_weights:
                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
                    #
                    if self.quantize_enable and self.quantize_bias:
                        if conv.bias is not None:
                            if bn is not None:
                                conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
                                conv.bias.data.copy_(conv_bias.data)
                            else:
                                conv.bias.data.copy_(merged_bias.data)
                            #
                        elif bn is not None and bn.bias is not None:
                            bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
                            bn.bias.data.copy_(bn_bias.data)
                        #
                    #
                #
            #
        #
        return conv, merged_weight, merged_bias
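
    # BatchNorm folding sketch (per output channel), restating the expressions above and
    # assuming running statistics mean/var and affine parameters gamma/beta:
    #   merged_scale  = gamma / sqrt(var + eps)
    #   merged_weight = conv_weight * merged_scale
    #   merged_bias   = (conv_bias - mean) * merged_scale + beta
    # merged_weight/merged_bias are what gets quantized; conv.weight itself is only
    # refreshed occasionally by multiplying back with merged_scale_inv.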

    def get_clips_w(self, tensor):
        # find the clip values
        w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
        clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
        clip_max = torch.clamp(clip_max, min=self.eps)
        # in range learning mode + training - this power2 is taken care of in the quantize function
        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
        clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
        clip_min2 = -clip_max2
        return (clip_min2, clip_max2)

    # bias uses the same kind of clips
    get_clips_bias = get_clips_w

    def get_clips_scale_w(self, weight):
        # convert to scale
        clip_min, clip_max = self.get_clips_w(weight)
        width_min, width_max = self.get_widths_w()
        scale2 = (width_max / clip_max)
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)
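
    # Numeric sketch (assuming bitwidth_weights=8): if clip_max works out to 0.5 then
    # width_max = 2^7 = 128 and scale2 = 128/0.5 = 256, so a weight w is mapped to
    # round(w*256), clamped to [-128, 127], and restored with scale_inv2 = 1/256.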

    # in reality, bias quantization will also depend on the activation scale
    # this is not perfect - just a quick and dirty quantization for bias
    def get_clips_scale_bias(self, bias):
        # convert to scale
        clip_min, clip_max = self.get_clips_bias(bias)
        width_min, width_max = self.get_widths_bias()
        scale2 = (width_max / clip_max)
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)

    def get_widths_w(self):
        # weights
        bw = (self.bitwidth_weights - 1)
        width_max = np.power(2.0, bw)
        width_min = -width_max
        # return
        return (width_min, width_max)

    def get_widths_bias(self):
        # bias
        bitwidth_bias = (2*self.bitwidth_activations)
        bias_width_max = np.power(2.0, bitwidth_bias-1)
        bias_width_min = -bias_width_max
        # return
        return (bias_width_min, bias_width_max)

    # activation utility functions
    def get_clips_scale_act(self):
        # convert to scale
        clip_min, clip_max = self.get_clips_act()
        width_min, width_max = self.get_widths_act()
        scale2 = width_max / clip_max
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)

    def get_widths_act(self):
        if self.signed is None:
            clip_min, clip_max = self.get_clips_act()
            signed = (clip_min < 0.0)
        else:
            signed = self.signed
        #
        bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
        width_max = np.power(2.0, bw)
        width_min = -width_max if signed else 0.0
        return width_min, width_max
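

# End-to-end sketch (illustrative, assuming 8-bit settings and a dummy input): a
# conv-bn pair followed by QuantTrainPAct2 folds the BatchNorm into the convolution
# and fake-quantizes both the merged weights and the activations in the forward pass.
def _example_conv_bn_pact2():
    conv = QuantTrainConv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = QuantTrainBatchNorm2d(8)
    pact = QuantTrainPAct2(signed=False, bitwidth_weights=8, bitwidth_activations=8)
    x = torch.rand(1, 3, 32, 32)
    y = pact(bn(conv(x)))
    return y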