###########################################################
# Approximate quantized floating point simulation with gradients.
# Can be used for quantized training of models.
###########################################################

import torch
import numpy as np
import copy
import warnings

from .. import layers
from .. import utils
from .quant_train_module import *

warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


###########################################################
class QuantCalibrateModule(QuantTrainModule):
    def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=True, constrain_weights=None,
                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05, **kwargs):
        self.weights_calibration = False
        self.lr_calib = lr_calib
        self.calibration_factor = lr_calib
        self.calibration_gamma = 0.5
        self.calibrate_repeats = 1
        self.quantize_enable = True
        self.update_activation_range = True
        constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
        super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias, **kwargs)

    def forward(self, inputs, *args, **kwargs):
        # calibration doesn't need gradients
        with torch.no_grad():
            # counters such as num_batches_tracked are used - update them
            self.update_counters()

            # backup the current training/eval state
            training = self.training

            # Set all BNs to eval so that they do not change. We can't set the whole model to eval,
            # because the PACT layers learn the activation ranges only in training mode.
            # Also, the model output itself may be different in eval mode (in certain cases -
            # for example, a segmentation model may do argmax instead of softmax in eval mode).
            utils.freeze_bn(self)

            # actual forward call
            if self.training and (self.bias_calibration or self.weights_calibration):
                # calibration
                outputs = self.forward_calibrate(inputs, *args, **kwargs)
            else:
                outputs = self.module(inputs, *args, **kwargs)
            #

            self.train(training)
        #
        return outputs

    def forward_calibrate(self, inputs, *args, **kwargs):
        # gradients are not needed for calibration
        # prepare/backup weights
        if self.num_batches_tracked == 0:
            # decay the calibration step size (lr_calib) by calibration_gamma every epoch
            self.calibration_factor = self.lr_calib * np.power(self.calibration_gamma, float(self.epoch))
            # backup original weights
            self._backup_weights_orig()
            # backup quantized weights
            self._backup_weights_quant()
        #

        # Compute the float (unquantized) outputs first and record their statistics.
        outputs = self.forward_float(inputs, *args, **kwargs)
        # Then adjust weights/bias so that the quantized output matches the float output.
        outputs = self.forward_quantized(inputs, *args, **kwargs)

        return outputs

    def forward_float(self, inputs, *args, **kwargs):
        self._restore_weights_orig()
        # disable quantization and activation range update for this float forward pass
        quantize_enable_backup_value, update_activation_range_backup_value = self.quantize_enable, self.update_activation_range
        utils.apply_setattr(self, quantize_enable=False, update_activation_range=False)

        self.add_call_hook(self.module, self.forward_float_hook)
        outputs = self.module(inputs, *args, **kwargs)
        self.remove_call_hook(self.module)

        # turn quantization back on - not a clean method
        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_activation_range=update_activation_range_backup_value)
        self._backup_weights_orig()
        return outputs
    #
    def forward_float_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)

        # record float reference statistics at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        while isinstance(output, (list, tuple)):
            output = output[0]

        # reduce over batch and spatial dimensions so the statistics are per output channel
        reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)

        bias = op.bias if hasattr(op, 'bias') else None
        if (self.bias_calibration and bias is not None):
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims).data
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims).data
        #
        return outputs
    #

    def forward_quantized(self, input, *args, **kwargs):
        self._restore_weights_quant()
        self.add_call_hook(self.module, self.forward_quantized_hook)
        for _ in range(self.calibrate_repeats):
            output = self.module(input, *args, **kwargs)
        #
        self.remove_call_hook(self.module)
        self._backup_weights_quant()
        return output
    #
    def forward_quantized_hook(self, op, input, *args, **kwargs):
        outputs = op.__forward_orig__(input, *args, **kwargs)

        # calibration at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        while isinstance(output, (list, tuple)):
            output = output[0]

        bias = op.bias if hasattr(op, 'bias') else None
        if self.bias_calibration and bias is not None:
            # nudge the bias so that the quantized output mean moves towards the float output mean
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_mean = torch.mean(output, dim=reduce_dims).data
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.calibration_factor
            bias.data += output_delta
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            # scale the weights so that the quantized output std moves towards the float output std
            eps = 1e-6
            weight = op.weight
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_std = torch.std(output, dim=reduce_dims).data
            output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
            channels = output.size(1)
            output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
            output_ratio = torch.pow(output_ratio, self.calibration_factor)
            weight.data *= output_ratio
        #
        return outputs
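

###########################################################
# Illustrative usage sketch (an assumption, not part of the original file):
# wrap a float model with QuantCalibrateModule and run a few forward passes
# over representative data in training mode; the calibration hooks above then
# collect float statistics and nudge the quantized biases/weights to match.
# The tiny Sequential model, input shape and iteration count below are
# placeholders - the devkit normally wraps a full network and iterates over
# a real calibration dataset.
if __name__ == '__main__':
    import torch.nn as nn

    float_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 4, kernel_size=3, padding=1)
    )
    dummy_input = torch.rand(1, 3, 32, 32)

    # wrap the model for calibration; bitwidths default to 8-bit weights/activations
    calib_model = QuantCalibrateModule(float_model, dummy_input=dummy_input)

    # calibration runs with the wrapper in train() mode, but gradients are not
    # needed - QuantCalibrateModule.forward() wraps the call in torch.no_grad()
    calib_model.train()
    for _ in range(4):
        calib_input = torch.rand(1, 3, 32, 32)
        _ = calib_model(calib_input)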