# modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
###########################################################
# Approximate quantized floating point simulation with gradients.
# Can be used for quantized training of models.
###########################################################

import torch
import numpy as np
import copy
import warnings

from .. import layers
from .. import utils
from .quant_train_module import *

warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


###########################################################
class QuantCalibrateModule(QuantTrainModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
        self.weights_calibration = False
        self.lr_calib = lr_calib
        self.calibration_factor = lr_calib
        self.calibration_gamma = 0.5
        self.calibrate_repeats = 1
        self.quantize_enable = True
        self.update_range = True
        constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                         constrain_weights=constrain_weights)
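
    # Illustrative usage sketch; 'model', 'calib_loader' and the input shape
    # below are placeholder assumptions, not names defined by this module:
    #
    #   module = QuantCalibrateModule(model, dummy_input=torch.rand(1, 3, 224, 224))
    #   module.train()  # calibration runs in training mode, but without gradients
    #   for images in calib_loader:
    #       module(images)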

    def forward(self, inputs):
        # calibration doesn't need gradients
        with torch.no_grad():
            # counters such as num_batches_tracked are used by the calibration schedule - update them.
            self.update_counters()

            # backup the current training/eval state
            training = self.training

            # Set all BNs to eval so that their running statistics do not change.
            # We can't set the whole model to eval because PACT needs to learn the
            # activation ranges - and that happens only in training mode. Also, the
            # model output itself may differ in eval mode (in certain cases - for
            # example, a segmentation model may apply argmax instead of softmax in eval mode).
            utils.freeze_bn(self)
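
            # For reference, a generic BN freeze looks like the sketch below;
            # whether utils.freeze_bn does exactly this is an assumption:
            #     for m in self.modules():
            #         if isinstance(m, torch.nn.BatchNorm2d):
            #             m.eval()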
50 # actual forward call
51 if self.training and (self.bias_calibration or self.weights_calibration):
52 # calibration
53 outputs = self.forward_calibrate(inputs)
54 else:
55 outputs = self.module(inputs)
56 #
58 self.train(training)
59 #
60 return outputs

    def forward_calibrate(self, inputs):
        # we don't need gradients for calibration
        # prepare/backup weights
        if self.num_batches_tracked == 0:
            # decay the calibration step size with the epoch
            self.calibration_factor = self.lr_calib * np.power(self.calibration_gamma, float(self.epoch))
            # backup original weights
            self._backup_weights_orig()
            # backup quantized weights
            self._backup_weights_quant()
        #

        # Compute the mean output in float first.
        outputs = self.forward_float(inputs)
        # Then adjust the weights/biases so that the quantized output matches the float output.
        outputs = self.forward_quantized(inputs)

        return outputs
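
    # The calibration step size decays geometrically with the epoch:
    #     calibration_factor = lr_calib * calibration_gamma ** epoch
    # With the defaults lr_calib=0.05 and calibration_gamma=0.5 this gives
    # 0.05 at epoch 0, 0.025 at epoch 1, 0.0125 at epoch 2, and so on.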

    def forward_float(self, inputs):
        self._restore_weights_orig()
        # temporarily disable quantization and range estimation
        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
        utils.apply_setattr(self, quantize_enable=False, update_range=False)

        self.add_call_hook(self.module, self.forward_float_hook)
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)

        # turn quantization back on - not a clean method
        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
        self._backup_weights_orig()
        return outputs
    #
    def forward_float_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)

        # calibration at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        # reduce over batch and spatial dims, leaving one statistic per channel
        reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)

        bias = op.bias if hasattr(op, 'bias') else None
        if self.bias_calibration and bias is not None:
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims).data
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims).data
        #
        return outputs
    #
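
    # For a 4D NCHW output, torch.mean(output, dim=[0, 2, 3]) has shape (C,),
    # matching the per-channel bias/weight that forward_quantized_hook adjusts.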

    def forward_quantized(self, inputs):
        self._restore_weights_quant()
        self.add_call_hook(self.module, self.forward_quantized_hook)
        # repeating the forward pass can make the calibration updates more stable
        for _ in range(self.calibrate_repeats):
            output = self.module(inputs)
        #
        self.remove_call_hook(self.module)
        self._backup_weights_quant()
        return output
    #
    def forward_quantized_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)

        # calibration at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

        bias = op.bias if hasattr(op, 'bias') else None
        if self.bias_calibration and bias is not None:
            # move the bias towards the float output mean captured in forward_float_hook
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_mean = torch.mean(output, dim=reduce_dims).data
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.calibration_factor
            bias.data += output_delta
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            # scale the weights so that the quantized output std moves towards the float output std
            eps = 1e-6
            weight = op.weight
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_std = torch.std(output, dim=reduce_dims).data
            output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
            channels = output.size(1)
            output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
            output_ratio = torch.pow(output_ratio, self.calibration_factor)
            weight.data *= output_ratio
        #
        return outputs
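

if __name__ == '__main__':
    # Minimal sketch of the calibration update rules on toy tensors. The shapes
    # and names here are illustrative assumptions; the real module applies these
    # updates inside forward hooks on live layers.
    calibration_factor = 0.05
    eps = 1e-6
    float_out = torch.randn(8, 4, 16, 16)                      # float reference output (NCHW)
    quant_out = float_out + 0.1 * torch.randn_like(float_out)  # stand-in for the quantized output
    reduce_dims = [0, 2, 3]
    # bias calibration: nudge the bias towards the float per-channel mean
    bias = torch.zeros(4)
    output_delta = torch.mean(float_out, dim=reduce_dims) - torch.mean(quant_out, dim=reduce_dims)
    bias += output_delta * calibration_factor
    # weights calibration: scale the weights by the per-channel std ratio
    weight = torch.ones(4, 3, 3, 3)
    output_ratio = (torch.std(float_out, dim=reduce_dims) + eps) / (torch.std(quant_out, dim=reduce_dims) + eps)
    weight *= torch.pow(output_ratio, calibration_factor).view(4, 1, 1, 1)
    print('bias update:', bias)
    print('weight scale:', output_ratio)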