constrain_weights - in default scenario, this is not used with per_channel_q.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_module.py
1 ###########################################################
2 # Approximate quantized floating point simulation with gradients.
3 # Can be used for quantized training of models.
4 ###########################################################
6 import torch
7 import numpy as np
8 import copy
9 import warnings
11 from .. import layers
12 from .. import utils
13 from . import quant_utils
14 from .quant_base_module import *
# torch.jit tracing of the fake-quantized graph emits TracerWarnings - silence them for the whole process
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
19 ###########################################################
class QuantTrainModule(QuantBaseModule):
    """Quantization aware training wrapper: simulates fixed point inference with float ops and gradients.

    Conv/deconv/bn and activation sub-modules of `module` are swapped for their QuantTrain*
    counterparts (see model_surgery_quantize) and the quantization settings are pushed to
    every sub-module via utils.apply_setattr.
    """
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=None):
        # by default, weights are constrained only when per channel quantization is not used
        if constrain_weights is None:
            constrain_weights = not per_channel_q
        #
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                         constrain_weights=constrain_weights, model_surgery_quantize=True)
        # percentile range shrink is only used together with histogram ranges - 0.0 disables shrinking
        percentile_range_shrink = layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0
        # broadcast the settings to every sub-module - the behaviour can be controlled from here
        module_settings = dict(bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                               per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
                               percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
                               update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True,
                               quantize_activations=True)
        utils.apply_setattr(self, **module_settings)

    def forward(self, inputs):
        # keep counters such as num_batches_tracked up to date, then run the wrapped model
        self.update_counters()
        return self.module(inputs)

    def model_surgery_quantize(self, dummy_input):
        """Recursively replace conv/deconv/bn/PAct2/NoAct/ReLU/ReLU6 children with QuantTrain* modules."""
        super().model_surgery_quantize(dummy_input)

        def _quant_module_for(m):
            # build the QuantTrain* replacement for m - None means m is kept as is
            if utils.is_conv(m):
                return QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size,
                                        stride=m.stride, padding=m.padding, dilation=m.dilation, groups=m.groups,
                                        bias=(m.bias is not None), padding_mode=m.padding_mode)
            if utils.is_deconv(m):
                return QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size,
                                                 stride=m.stride, padding=m.padding, output_padding=m.output_padding,
                                                 groups=m.groups, bias=(m.bias is not None), dilation=m.dilation,
                                                 padding_mode=m.padding_mode)
            if utils.is_bn(m):
                return QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                             track_running_stats=m.track_running_stats)
            #
            bit_settings = dict(bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                per_channel_q=self.per_channel_q)
            if isinstance(m, layers.PAct2):
                return QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range, **bit_settings)
            if isinstance(m, layers.NoAct):
                # signed=None - signedness is decided from the collected clip range
                return QuantTrainPAct2(signed=None, **bit_settings)
            if isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                return QuantTrainPAct2(signed=False, **bit_settings)
            #
            return None

        def _transfer_state(src, dst):
            # tensors are copied in place; sub-modules and a few known attributes are shared by reference
            for attr in dir(src):
                value = getattr(src, attr)
                if isinstance(value, torch.Tensor):
                    getattr(dst, attr).data.copy_(value.data)
                elif isinstance(value, torch.nn.Module) or attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'):
                    # add more attribute names to the tuple above when needed
                    setattr(dst, attr, value)
                #
            #

        def replace_func(op):
            for name, m in op._modules.items():
                new_m = _quant_module_for(m)
                if new_m is None:
                    continue
                #
                _transfer_state(m, new_m)
                new_m.train(m.training)
                setattr(op, name, new_m)
            #

        # apply recursively, then clear
        self.apply(replace_func)
        self.clear_states()
    #
100 ###########################################################
class QuantTrainParams:
    """Plain attribute bag attached to tensors as `qparams`.

    get_qparams() gives it an `inputs` list (original inputs) and a `modules` list
    (the conv/bn modules that produced the tagged tensor).
    """
def get_qparams():
    """Create an empty QuantTrainParams record with fresh `inputs` and `modules` lists."""
    record = QuantTrainParams()
    record.inputs, record.modules = [], []
    return record
def is_merged_layer(x):
    """Return True when x carries a QuantTrainParams `qparams` attribute with at least one module."""
    if not hasattr(x, 'qparams'):
        return False
    #
    qparams = x.qparams
    return isinstance(qparams, QuantTrainParams) and len(qparams.modules) > 0
117 ###########################################################
class QuantTrainConv2d(torch.nn.Conv2d):
    """Conv2d whose output is tagged so that a following QuantTrainPAct2 can merge and quantize it.

    The forward result is the plain float convolution. When quantize_enable is set, the output
    tensor also gets a `qparams` attribute holding this module's input and the module itself.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # quantization settings - presumably overwritten via utils.apply_setattr by QuantTrainModule
        self.quantize_enable = True
        self.bitwidth_weights = None
        self.bitwidth_activations = None
        self.per_channel_q = False

    def forward(self, x):
        if is_merged_layer(x):
            # x still carries conv/bn qparams, i.e. no PAct consumed (and quantized) it
            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
        #
        y = super().forward(x)
        if self.quantize_enable:
            # record what QuantTrainPAct2 needs to redo this conv with merged, quantized weights
            record = get_qparams()
            record.inputs.append(x)
            record.modules.append(self)
            y.qparams = record
        #
        return y
    #
148 ###########################################################
class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
    """ConvTranspose2d whose output is tagged so that a following QuantTrainPAct2 can merge and quantize it.

    The forward result is the plain float transposed convolution. When quantize_enable is set,
    the output tensor also gets a `qparams` attribute holding this module's input and the module itself.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # quantization settings - presumably overwritten via utils.apply_setattr by QuantTrainModule
        self.quantize_enable = True
        self.bitwidth_weights = None
        self.bitwidth_activations = None
        self.per_channel_q = False

    def forward(self, x):
        if is_merged_layer(x):
            # x still carries conv/bn qparams, i.e. no PAct consumed (and quantized) it
            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
        #
        y = super().forward(x)
        if self.quantize_enable:
            # record what QuantTrainPAct2 needs to redo this deconv with merged, quantized weights
            record = get_qparams()
            record.inputs.append(x)
            record.modules.append(self)
            y.qparams = record
        #
        return y
    #
179 ###########################################################
class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that extends a conv/deconv `qparams` record so the pair can be merged downstream."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True

    def forward(self, x):
        y = super().forward(x)
        # only a bn that directly follows a tagged conv/deconv is recorded - otherwise y stays untagged
        follows_conv = self.quantize_enable and is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1])
        if follows_conv:
            prev = x.qparams
            record = get_qparams()
            record.inputs = [prev.inputs[0], x]
            record.modules = [prev.modules[0], self]
            y.qparams = record
        #
        return y
    #
205 ###########################################################
206 # fake quantized PAct2 for training
207 class QuantTrainPAct2(layers.PAct2):
208     def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
209         super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
211         self.bitwidth_weights = bitwidth_weights
212         self.bitwidth_activations = bitwidth_activations
213         self.per_channel_q = per_channel_q
214         # weight shrinking is done by clamp weights - set this factor to zero.
215         # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
216         # so any clipping we do here is not stored int he weight params
217         self.range_shrink_weights = 0.0
218         self.round_dither = 0.0
219         self.update_range = True
220         self.quantize_enable = True
221         self.quantize_weights = True
222         self.quantize_bias = True
223         self.quantize_activations = True
224         self.constrain_weights = True
225         self.bias_calibration = False
226         # save quantized weight/bias once in a while into the params - not needed
227         self.params_save_frequency = None #(10 if self.bias_calibration else None)
229         # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
230         # For a comparison of STE and ABE, read:
231         # Learning low-precision neural networks without Straight-Through Estimator (STE):
232         # https://arxiv.org/pdf/1903.01061.pdf
233         self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
234         self.alpha_blending_estimation_factor = 0.5
236         if (layers.PAct2.PACT2_RANGE_LEARN):
237             assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
238                 'straight through estimation should not used when PACT clip values are being learned as it doesnt backpropagate gradients though quantization'
239         #
242     def forward(self, x):
243         assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
244                         'bitwidth_weights and bitwidth_activations must not be None'
246         # the pact range update happens here - but range clipping depends on quantize_enable
247         y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
249         if not self.quantize_enable:
250             return y
251         #
253         # previous intermediate outputs and other infoirmation are avaliable
254         # for example - conv-bn-relu may need to be merged together.
255         is_merged = is_merged_layer(x)
256         if is_merged:
257             qparams = x.qparams
258             xorg = qparams.inputs[0]
260             conv, bn = None, None
261             # merge weight and bias (if possible) across layers
262             if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
263                     qparams.modules[-1], torch.nn.BatchNorm2d):
264                 conv = qparams.modules[-2]
265                 bn = qparams.modules[-1]
266             elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
267                 conv = qparams.modules[-1]
268             elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
269                 assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - prease inspect the model carefully'
270                 bn = qparams.modules[-1]
271             #
272             else:
273                 assert False, f'QuantTrainPAct2: both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
274             #
275             conv, weight, bias = self.merge_quantize_weights(conv, bn)
276         else:
277             conv, weight, bias = None, None, None
278         #
280         if is_merged and utils.is_conv(conv):
281             xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
282         elif is_merged and utils.is_deconv(conv):
283             xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
284                                                       dilation=conv.dilation, groups=conv.groups)
285         else:
286             xq = x
287         #
289         if (self.quantize_enable and self.quantize_activations):
290             clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
291             width_min, width_max = self.get_widths_act()
292             # no need to call super().forward here as clipping with width_min/windth_max-1 after scaling has the same effect.
293             # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
294             # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
295             # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
296             yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
297         else:
298             yq = super().forward(xq, update_range=False, enable=True)
299         #
301         if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
302             # replace the float output with quantized version
303             # the entire weight merging and quantization process is bypassed in the forward pass
304             # however, the backward gradients flow through only the float path - this is called straight through estimation (STE)
305             with torch.no_grad():
306                 y.data.copy_(yq.data)
307             #
308         elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
309             # TODO: vary the alpha blending factor over the epochs
310             y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
311         elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
312             # pass on the quantized output - the backward gradients also flow through quantization.
313             # however, note the gradients of round and ceil operators are forced to be unity (1.0).
314             y = yq
315         else:
316             assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
317         #
319         return y
320     #
323     def apply_constrain_weights(self, merged_weight):
324         return quant_utils.constrain_weight(merged_weight)
327     def merge_quantize_weights(self, conv, bn):
328         # store the quantized weights and biases in-frequently - otherwise learning will be poor
329         # since this may not be done at the of the epoch, there can be a slight mismatch in validation accuracy
330         first_training_iter = self.training and (self.num_batches_tracked == 0)
331         is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
333         # merge weight and bias (if possible) across layers
334         if conv is not None and bn is not None:
335             conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
336             #
337             bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
338             bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
339             #
340             merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
341             if utils.is_conv(conv):
342                 merged_scale = merged_scale.view(-1, 1, 1, 1)
343             elif utils.is_deconv(conv):
344                 merged_scale = merged_scale.view(1, -1, 1, 1)
345             else:
346                 assert False, 'unable to merge convolution and BN'
347             #
348             merged_bias = (conv_bias - bn.running_mean) * merged_scale.view(-1) + bn_bias
349             merged_weight = conv.weight * merged_scale
350             #
351             merged_scale_sign = merged_scale.sign()
352             merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # make the 0s in sign to 1
353             merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
354             merged_scale_inv = 1.0 / merged_scale_eps
355             #
356         elif conv is not None:
357             merged_weight = conv.weight
358             merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
359             merged_scale = 1.0
360             merged_scale_inv = 1.0
361         elif bn is not None:
362             merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
363             merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
364         else:
365             assert False, f'merge_quantize_weights(): both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
366             merged_weight = 0.0
367             merged_bias = 0.0
368         #
370         # quantize weight and bias
371         if (conv is not None):
372             if (self.quantize_enable and self.quantize_weights):
373                 if self.constrain_weights and first_training_iter:
374                     with torch.no_grad():
375                         # clamp merged weights, invert the bn and copy to conv weight
376                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
377                         merged_weight.data.copy_(constrained_weight.data)
378                         # store clipped weight after inverting bn - not really needed as there is a saving below as well
379                         # conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
380                     #
381                 #
383                 is_dw = utils.is_dwconv(conv)
384                 use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
385                 if use_per_channel_q:
386                     channels = int(merged_weight.size(0))
387                     scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
388                     scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
389                     for chan_id in range(channels):
390                         clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
391                         scale2[chan_id,0,0,0] = scale2_value
392                         scale_inv2[chan_id,0,0,0] = scale_inv2_value
393                     #
394                 else:
395                     clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
396                 #
397                 width_min, width_max = self.get_widths_w()
398                 # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
399                 merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
400             #
402             if (self.quantize_enable and self.quantize_bias):
403                 bias_width_min, bias_width_max = self.get_widths_bias()
404                 bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
405                 # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
406                 merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
407             #
409             # invert the bn operation and store weights/bias
410             if first_training_iter or (self.training and is_store_weight_bias_iter):
411                 with torch.no_grad():
412                     if self.quantize_enable and self.quantize_weights:
413                         conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
414                     #
415                     if self.quantize_enable and self.quantize_bias:
416                         if conv.bias is not None:
417                             if bn is not None:
418                                 conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
419                                 conv.bias.data.copy_(conv_bias.data)
420                             else:
421                                 conv.bias.data.copy_(merged_bias.data)
422                             #
423                         elif bn is not None and bn.bias is not None:
424                             bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
425                             bn.bias.data.copy_(bn_bias.data)
426                         #
427                     #
428                 #
429             #
430         #
431         return conv, merged_weight, merged_bias
434     def get_clips_w(self, tensor):
435         # find the clip values
436         w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
437         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
438         clip_max = torch.clamp(clip_max, min=self.eps)
439         # in range learning mode + training - this power2 is taken care in the quantize function
440         use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
441         clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
442         clip_min2 = -clip_max2
443         return (clip_min2, clip_max2)
445     # bias uses the same kind of clips
446     get_clips_bias = get_clips_w
449     def get_clips_scale_w(self, weight):
450         # convert to scale
451         clip_min, clip_max = self.get_clips_w(weight)
452         width_min, width_max = self.get_widths_w()
453         scale2 = (width_max / clip_max)
454         scale2 = torch.clamp(scale2, min=self.eps)
455         scale_inv2 = scale2.pow(-1.0)
456         return (clip_min, clip_max, scale2, scale_inv2)
459     # in reality, bias quantization will also depend on the activation scale
460     # this is not perfect - just a quick and dirty quantization for bias
461     def get_clips_scale_bias(self, bias):
462         # convert to scale
463         clip_min, clip_max = self.get_clips_bias(bias)
464         width_min, width_max = self.get_widths_bias()
465         scale2 = (width_max / clip_max)
466         scale2 = torch.clamp(scale2, min=self.eps)
467         scale_inv2 = scale2.pow(-1.0)
468         return (clip_min, clip_max, scale2, scale_inv2)
471     def get_widths_w(self):
472         # weights
473         bw = (self.bitwidth_weights - 1)
474         width_max = np.power(2.0, bw)
475         width_min = -width_max
476         # return
477         return (width_min, width_max)
480     def get_widths_bias(self):
481         # bias
482         bitwidth_bias = (2*self.bitwidth_activations)
483         bias_width_max = np.power(2.0, bitwidth_bias-1)
484         bias_width_min = -bias_width_max
485         # return
486         return (bias_width_min, bias_width_max)
489     # activation utility functions
490     def get_clips_scale_act(self):
491         # convert to scale
492         clip_min, clip_max = self.get_clips_act()
493         width_min, width_max = self.get_widths_act()
494         scale2 = width_max / clip_max
495         scale2 = torch.clamp(scale2, min=self.eps)
496         scale_inv2 = scale2.pow(-1.0)
497         return (clip_min, clip_max, scale2, scale_inv2)
500     def get_widths_act(self):
501         if self.signed is None:
502             clip_min, clip_max = self.get_clips_act()
503             signed = (clip_min < 0.0)
504         else:
505             signed = self.signed
506         #
507         bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
508         width_max = np.power(2.0, bw)
509         width_min = -width_max if signed else 0.0
510         return width_min, width_max