Author:     Manu Mathew <a0393608@ti.com>
AuthorDate: Wed, 27 May 2020 12:42:48 +0000 (18:12 +0530)
Commit:     Manu Mathew <a0393608@ti.com>
CommitDate: Wed, 27 May 2020 13:11:54 +0000 (18:41 +0530)
Parent:     fe0de31
release commit
diff --git a/modules/pytorch_jacinto_ai/vision/models/classification/__init__.py b/modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
index 8331aa19b7443420c0cbbf3ce060e68643704f9b..cf4437a81434979eeed1239f9c63c55bade216df 100644
#####################################################################
-def resnet50_x1(model_config, pretrained=None):
+def resnet50_x1(model_config=None, pretrained=None):
model_config = resnet.get_config().merge_from(model_config)
model = resnet.resnet50_with_model_config(model_config)
# the pretrained model provided by torchvision and what is defined here differs slightly
return model, change_names_dict
-def resnet50_xp5(model_config, pretrained=None):
+def resnet50_xp5(model_config=None, pretrained=None):
model_config.width_mult = 0.5
return resnet50_x1(model_config=model_config, pretrained=pretrained)
#####################################################################
-def resnet18_x1(model_config, pretrained=None):
+def resnet18_x1(model_config=None, pretrained=None):
model_config = resnet.get_config().merge_from(model_config)
model = resnet.resnet18_with_model_config(model_config)
# the pretrained model provided by torchvision and what is defined here differs slightly
#####################################################################
-def mobilenetv1_x1(model_config, pretrained=None):
+def mobilenetv1_x1(model_config=None, pretrained=None):
model_config = mobilenetv1.get_config().merge_from(model_config)
model = mobilenetv1.MobileNetV1(model_config=model_config)
if pretrained:
model = xnn.utils.load_weights(model, pretrained)
return model
-def mobilenetv1_multi_label_x1(model_config, pretrained=None):
+def mobilenetv1_multi_label_x1(model_config=None, pretrained=None):
model_config = mobilenetv1.get_config().merge_from(model_config)
model = mobilenetv1_internal.MobileNetV1MultiLabel(model_config=model_config)
if pretrained:
#####################################################################
-def mobilenetv2_tv_x1(model_config, pretrained=None):
+def mobilenetv2_tv_x1(model_config=None, pretrained=None):
model_config = mobilenetv2.get_config().merge_from(model_config)
model = mobilenetv2.MobileNetV2TV(model_config=model_config)
if pretrained:
mobilenetv2_x1 = mobilenetv2_tv_x1
-def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
+def mobilenetv2_tv_x2_t2(model_config=None, pretrained=None):
model_config = mobilenetv2.get_config().merge_from(model_config)
model_config.width_mult = 2.0
model_config.expand_ratio = 2.0
#####################################################################
-def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
+def mobilenetv2_tv_gws_x1(model_config=None, pretrained=None):
model_config = mobilenetv2_internal.get_config_mnetv2_gws().merge_from(model_config)
model = mobilenetv2_internal.MobileNetV2TVGWS(model_config=model_config)
if pretrained:
#####################################################################
-def mobilenetv2_ericsun_x1(model_config, pretrained=None):
+def mobilenetv2_ericsun_x1(model_config=None, pretrained=None):
model_config = mobilenetv2_ericsun_internal.get_config().merge_from(model_config)
model = mobilenetv2_ericsun_internal.MobileNetV2Ericsun(model_config=model_config)
if pretrained:
return model
-def mobilenetv2_shicai_x1(model_config, pretrained=None):
+def mobilenetv2_shicai_x1(model_config=None, pretrained=None):
model_config = mobilenetv2_shicai_internal.get_config().merge_from(model_config)
model = mobilenetv2_shicai_internal.mobilenetv2_shicai(model_config=model_config)
if pretrained:
model = xnn.utils.load_weights(model, pretrained)
return model
-
-def flownetslite_base_x1(model_config, pretrained=None):
- model_config = flownetbase_internal.get_config().merge_from(model_config)
- model = flownetbase_internal.flownetslite_base(model_config, pretrained=pretrained)
- if pretrained:
- model = xnn.utils.load_weights(model, pretrained)
- return model
\ No newline at end of file
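All of these constructor changes make model_config optional. A minimal usage sketch (the checkpoint path and my_model_config are placeholders, not from this commit):

    # with the new default, the model can be built from the default config
    model, change_names_dict = resnet50_x1(pretrained='./resnet50_checkpoint.pth')

    # an explicit config can still be passed to override the defaults
    model, change_names_dict = resnet50_x1(model_config=my_model_config)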
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/activation.py b/modules/pytorch_jacinto_ai/xnn/layers/activation.py
index 325764e0caaea4cc4fb3dbe87ba9283f709fb41a..a1e6ee68af0dab1a6891e821eea07959b237b5a8 100644
PACT2_RANGE_INIT = 8.0 # this is the starting range
PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
- def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, **kwargs):
+ def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, power2_activation_range=True, **kwargs):
super().__init__()
if (clip_range is not None) and (signed is not None):
assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
self.fixed_range = (clip_range is not None)
self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
self.eps = np.power(2.0, -16.0)
- self.power2 = True # power of 2 ranges
+ self.power2_activation_range = power2_activation_range # power of 2 ranges
self.log_base = None # 2.0 # log is used only in learned mode if log_base is not None
# any validation before at least one iteration of training will use default clip values.
#
- def forward(self, x, update_range=True, enable=True):
- if (self.training and update_range):
+ def forward(self, x, update_activation_range=True, enable=True):
+ if (self.training and update_activation_range):
self.num_batches_tracked += 1
# even in learn_range mode - do this for a few iterations to get a good starting point
if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
#
if not enable:
signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
- return x if signed else torch.nn.functional.relu(x)
+ y = x if signed else torch.nn.functional.relu(x)
+ else:
+ clips = self.get_clips_act()
+ y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
#
- clips = self.get_clips_act()
- y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
return y
clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
clip_max = torch.clamp(clip_max, min=self.eps)
clip_max = self.convert_to_linear(clip_max)
- # in range learning mode + training - this power2 is taken care in the quantize function
- use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
- clip_max2 = ceil2_g(clip_max) if use_power2 else clip_max
+ # in range learning mode + training - this power2_activation_range is taken care in the quantize function
+ is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+ use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
+ clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
return (clip_min2, clip_max2)
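For reference, the power-of-2 expansion applied above can be reproduced numerically; this is an illustrative stand-in for ceil2_g (the real function also defines a gradient):

    import numpy as np

    def ceil2(x):
        # round a positive clip value up to the next power of two
        return np.power(2.0, np.ceil(np.log2(x)))

    # with power2_activation_range=True, a calibrated clip_max of 5.3 becomes 8.0,
    # i.e. a signed activation range of [-8.0, 8.0]
    print(ceil2(5.3))  # 8.0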
###############################################################
-class NoAct(torch.nn.Module):
+# Always quantized activation function.
+# Inserting this activation function is a simple way to ensure quantization happens at certain places.
+class QAct(torch.nn.Module):
+ def __init__(self, inplace=False, signed=True, **kwargs):
+ super().__init__()
+ self.inplace = inplace
+ self.signed = signed
+
+ def forward(self, x):
+ return x
+
+
+# Never quantized activation function.
+# Also, if the next block is a NoQAct, the previous block's output is not quantized.
+# Inserting this activation function is a simple way to avoid quantization at certain places.
+class NoQAct(torch.nn.Module):
def __init__(self, inplace=False, signed=True, **kwargs):
super().__init__()
self.inplace = inplace
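A sketch of how the two markers might be used in a model definition (the module and import path are assumptions for illustration):

    import torch
    from pytorch_jacinto_ai import xnn  # assumed import path

    class SmallHead(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(32, 16, kernel_size=1)
            self.force_q = xnn.layers.QAct()   # ensure this point gets quantized
            self.skip_q = xnn.layers.NoQAct()  # keep the final output unquantized

        def forward(self, x):
            return self.skip_q(self.force_q(self.conv(x)))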
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
index 30a8f91d58116c0670b30d975a2968cd567948d4..fa6b75396e10e1ac25de878c5c6a2b02aae217fe 100644
# base module to be used for all quantization modules
class QuantBaseModule(QuantGraphModule):
def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
- histogram_range=True, bias_calibration=False, constrain_weights=False,
- model_surgery_quantize=False):
+ histogram_range=True, bias_calibration=False, constrain_weights=False, constrain_bias=None,
+ model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None):
super().__init__(module)
self.bitwidth_weights = bitwidth_weights
self.bitwidth_activations = bitwidth_activations
self.per_channel_q = per_channel_q
self.histogram_range = histogram_range
self.constrain_weights = constrain_weights
+ self.constrain_bias = True if (constrain_bias is None) else constrain_bias
self.bias_calibration = bias_calibration
+ self.power2_weight_range = True if (power2_weight_range is None) else power2_weight_range
+ self.power2_activation_range = True if (power2_activation_range is None) else power2_activation_range
+ # range shrink - 0.0 indicates no shrink
+ self.percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
# for help in debug/print
utils.add_module_names(self)
# put in eval mode before analyze
# for help in debug/print
utils.add_module_names(self)
+ # set attributes to all modules - can control the behaviour from here
+ utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
+ histogram_range=histogram_range, bias_calibration=self.bias_calibration, per_channel_q=self.per_channel_q,
+ percentile_range_shrink=self.percentile_range_shrink, constrain_weights=self.constrain_weights,
+ power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range,
+ constrain_bias=self.constrain_bias)
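utils.apply_setattr broadcasts these settings to the submodules; conceptually it behaves like the sketch below (the actual helper may differ, e.g. by updating only modules that already define the attribute):

    def apply_setattr_sketch(model, **kwargs):
        # push each setting onto every submodule so all quantized layers
        # see the same bitwidths, range options and constraints
        for m in model.modules():
            for name, value in kwargs.items():
                setattr(m, name, value)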
def add_activation_hooks(self):
# add a forward hook to call the extra activation that we added
- def _forward_activation(op, inputs, outputs):
+ def _forward_input_activation(op, inputs):
+ if hasattr(op, 'activation_in'):
+ # hook passes the input as tuple - expand it
+ to_squeeze = isinstance(inputs, tuple) and len(inputs) == 1
+ inputs = inputs[0] if to_squeeze else inputs
+ inputs = op.activation_in(inputs)
+ inputs = (inputs,) if to_squeeze else inputs
+ #
+ return inputs
+ #
+ def _forward_output_activation(op, inputs, outputs):
if hasattr(op, 'activation_q'):
+ # hook passes the input as tuple - expand it
+ to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
+ outputs = outputs[0] if to_squeeze else outputs
outputs = op.activation_q(outputs)
+ outputs = (outputs,) if to_squeeze else outputs
#
return outputs
#
for m in self.modules():
- m.register_forward_hook(_forward_activation)
+ m.register_forward_pre_hook(_forward_input_activation)
+ m.register_forward_hook(_forward_output_activation)
#
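The wrap/unwrap above follows PyTorch's hook convention: a forward pre-hook receives the positional inputs as a tuple, and a returned value replaces them. A standalone illustration:

    import torch

    m = torch.nn.Identity()

    def pre_hook(module, inputs):    # inputs arrives as a tuple, e.g. (x,)
        return (inputs[0] * 2.0,)    # a returned tuple replaces the inputs

    m.register_forward_pre_hook(pre_hook)
    print(m(torch.ones(1)))          # tensor([2.])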
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
index 266b0dabc0a3aed066194f9cf84f70a8899b457b..954ea45368fa25a72eeb16f59e28c63f545e1d58 100644
###########################################################
class QuantCalibrateModule(QuantTrainModule):
def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
- histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
+ histogram_range=True, bias_calibration=True, constrain_weights=None,
+ power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05):
self.weights_calibration = False
self.lr_calib = lr_calib
self.calibration_factor = lr_calib
self.calibration_gamma = 0.5
self.calibrate_repeats = 1
self.quantize_enable = True
- self.update_range = True
+ self.update_activation_range = True
constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
- per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
- constrain_weights=constrain_weights)
+ per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
+ power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
def forward(self, inputs):
def forward_float(self, inputs):
self._restore_weights_orig()
# disable quantization for a moment
- quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
- utils.apply_setattr(self, quantize_enable=False, update_range=False)
+ quantize_enable_backup_value, update_activation_range_backup_value = self.quantize_enable, self.update_activation_range
+ utils.apply_setattr(self, quantize_enable=False, update_activation_range=False)
self.add_call_hook(self.module, self.forward_float_hook)
outputs = self.module(inputs)
self.remove_call_hook(self.module)
# turn quantization back on - not a clean method
- utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
+ utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_activation_range=update_activation_range_backup_value)
self._backup_weights_orig()
return outputs
#
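With the new arguments threaded through, a calibration run could look like the following sketch (float_model and calib_loader are placeholders; passing None picks the defaults shown in QuantBaseModule):

    import torch

    model = QuantCalibrateModule(float_model, dummy_input=torch.rand(1, 3, 224, 224),
                                 bitwidth_weights=8, bitwidth_activations=8,
                                 power2_weight_range=None, power2_activation_range=None)
    model.train()  # ranges and bias calibration are updated in training mode
    with torch.no_grad():
        for images, _ in calib_loader:
            model(images)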
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
index 1086cb559fa50063a7d95bd99e3adff9a9c95d4c..cbcb3c1ebb6fd8c4a465ab5d1eb80d0b87b53cec 100644
self.num_batches_tracked = -1
self.iter_in_epoch = -1
self.epoch = -1
+ # these are the blocks whose output we quantize for sure.
+ # outputs of other blocks such as Conv2d, ConvTranspose2d, BatchNorm2d, Linear are quantized conditionally
+ self.quantize_out_blocks = (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh, layers.QAct, layers.PAct2,
+ layers.AddBlock, layers.CatBlock, layers.MultBlock, torch.nn.MaxPool2d, torch.nn.AvgPool2d)
+
+ # this block is not quantized. Also, if the next block is one of these, the current block's output is not quantized
+ self.ignore_out_blocks = (layers.NoQAct, torch.nn.Dropout2d)
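These tuples drive the quantize_out decision in _analyse_connections_op below; reduced to the membership checks alone (the full logic also inspects the following modules), the rule is:

    def output_quantization_hint(module, quantize_out_blocks, ignore_out_blocks):
        # ignore blocks always win; otherwise membership in
        # quantize_out_blocks marks the output for quantization
        if isinstance(module, ignore_out_blocks):
            return False
        return isinstance(module, quantize_out_blocks)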
# TBD: is this required
# # if the original module has load_weights, add it to the quant module also
self.clear_qstate()
# analyze
self.analyze_graph(dummy_input)
- # insert NoAct wherever range clipping needs to be done
+ # insert QAct wherever range clipping needs to be done
self.model_surgery_activations()
# since we might have added new activations, clear the states as they may not be valid
self.clear_qstate()
activation_q = layers.PAct2(signed=False)
elif isinstance(module, torch.nn.Hardtanh):
activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
- elif isinstance(module, layers.NoAct):
+ elif isinstance(module, layers.QAct):
activation_q = layers.PAct2(signed=None)
else:
activation_q = layers.PAct2(signed=None)
activation_q.train(self.training)
module.activation_q = activation_q
#
+ elif qparams.quantize_in:
+ if not hasattr(module, 'activation_in'):
+ activation_in = layers.PAct2(signed=None)
+ activation_in.train(self.training)
+ module.activation_in = activation_in
+ #
else:
pass
#
def _forward_analyze_modules_impl(self, inputs):
self.start_call()
- self.add_call_hook(self.module, self._analyze_modules_op)
+ self.add_call_hook(self, self._analyze_modules_op)
output = self.module(inputs)
self.remove_call_hook(self.module)
self.finish_call()
return output
def _analyze_modules_op(self, op, *inputs_orig):
- inputs = utils.squeeze_list(inputs_orig)
+ inputs = utils.squeeze_list2(inputs_orig)
self.start_node(op)
self.add_node(op, inputs)
outputs = op.__forward_orig__(*inputs_orig)
################################################################
def analyze_connections(self):
+ first_module = None
prediction_module = None
for module_hash, qparams in self.get_qstate().qparams.items():
module = self.get_module(module_hash)
- if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
+ if utils.is_conv_deconv_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
+ first_module = module if first_module is None else first_module
prediction_module = module
#
#
for module_hash, qparams in self.get_qstate().qparams.items():
module = self.get_module(module_hash)
- is_prediction = (prediction_module is module)
- self._analyse_connections_op(module_hash, module, qparams, is_prediction)
+ is_first_module = (first_module is module)
+ is_prediction_module = (prediction_module is module)
+ self._analyse_connections_op(module_hash, module, qparams, is_first_module, is_prediction_module)
#
- def _analyse_connections_op(self, module_hash, module, qparams, is_prediction):
- previous_module = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
- next_module = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
+ def _analyse_connections_op(self, module_hash, module, qparams, is_first_module, is_prediction_module):
+ previous_modules = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
+ next_modules = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
quantize_out = False
- if utils.is_activation(module):
- if len(next_module)==1 and utils.is_activation(next_module[0]):
+ if isinstance(module, self.ignore_out_blocks):
+ quantize_out = False
+ elif utils.is_activation(module):
+ if len(next_modules)==1 and utils.is_activation(next_modules[0]):
quantize_out = False
else:
quantize_out = True
#
- elif isinstance(module, (layers.AddBlock, layers.CatBlock, layers.MultBlock)):
- if len(next_module)==1 and utils.is_activation(next_module[0]):
+ elif isinstance(module, self.quantize_out_blocks):
+ if len(next_modules)==1 and utils.is_activation(next_modules[0]):
quantize_out = False
else:
quantize_out = True
#
elif utils.is_normalization(module):
- if len(next_module)==1 and utils.is_activation(next_module[0]):
+ if len(next_modules)==1 and utils.is_activation(next_modules[0]):
quantize_out = False
else:
quantize_out = True
#
elif utils.is_conv(module) or utils.is_deconv(module):
- if len(next_module)==1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
+ if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
quantize_out = False
else:
quantize_out = True
#
elif utils.is_linear(module):
- if len(next_module)==1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
+ if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
quantize_out = False
else:
quantize_out = True
# quantize_out = True
# #
- qparams.quantize_w = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear)) # all conv/deconv layers will be quantized
- qparams.quantize_b = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear)) # all conv/deconv layers will be quantized
- qparams.quantize_out = quantize_out # selectively quantize output
- qparams.quantize_in = qparams.is_input # only top modules's input need to be quantized
- qparams.align_in = isinstance(module, (layers.AddBlock, layers.CatBlock,torch.nn.AdaptiveAvgPool2d))# all tensors to be made same q at the input
- qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0 #additional scaleup to simulate fixed point
- qparams.unquantize_out = qparams.is_input # only top modules's output need to be unquantized
+ if len(qparams.previous_node) > 0:
+ previous_module_hash = qparams.previous_node[-1]
+ previous_module = self.get_module(previous_module_hash)
+ previous_module_qparams = self.get_qstate().qparams[previous_module_hash]
+ is_input_ignored = isinstance(previous_module, self.ignore_out_blocks)
+ is_input_quantized = previous_module_qparams.quantize_out if \
+ hasattr(previous_module_qparams, 'quantize_out') else False
+ else:
+ is_input_ignored = False
+ is_input_quantized = False
+ #
+
+ quantize_in = utils.is_conv_deconv_linear(module) and not is_input_quantized and \
+ not is_input_ignored and is_first_module
+ qparams.quantize_w = utils.is_conv_deconv_linear(module) # all conv/deconv/linear weights will be quantized
+ qparams.quantize_b = utils.is_conv_deconv_linear(module) # all conv/deconv/linear biases will be quantized
+ qparams.quantize_out = quantize_out # selectively quantize output
+ qparams.quantize_in = quantize_in # only the top module's input needs to be quantized
+ multi_input_blocks = (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d)
+ qparams.align_in = isinstance(module, multi_input_blocks) # all tensors to be made the same q at the input
+ qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0 # additional scale-up to simulate fixed point
+ qparams.unquantize_out = qparams.is_input # only the top module's output needs to be unquantized
qparams.is_dwconv = utils.is_dwconv(module)
- qparams.next_module = next_module
- qparams.is_prediction = is_prediction
+ qparams.next_modules = next_modules
+ qparams.is_first_module = is_first_module
+ qparams.is_prediction_module = is_prediction_module
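Condensing the new input-quantization rule above into one expression (same names as in this hunk): the input is quantized only at the first conv/deconv/linear module, and only if the previous block neither quantized its output nor is an ignore block:

    quantize_in = (utils.is_conv_deconv_linear(module)
                   and is_first_module
                   and not is_input_quantized
                   and not is_input_ignored)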
################################################################
is_conv = utils.is_conv_deconv(module)
# note: we consider merging only if there is a single next node
- next_module = qparams.next_module[0] if len(qparams.next_module) == 1 else None
+ next_module = qparams.next_modules[0] if len(qparams.next_modules) == 1 else None
next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None
# if the next module is a bn, apply the bn merging step
def format_tensors(self, inputs):
# make a list/tuple if inputs is not. if it is a double list, remove the extra one
- inputs = utils.squeeze_list(utils.make_list(inputs))
+ inputs = utils.squeeze_list2(utils.make_list(inputs))
# remove lists/tuple
inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
return inputs
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
index f0d47898ac97938f398e08a032bc23ca0dc7d092..76a139160718e44658dfca591a108eb12e5b4c6f 100644
import warnings
import numpy as np
+
+########################################################################
+from .quant_train_module import *
+
+class QuantTestModule(QuantTrainModule):
+ def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+ histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
+ power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
+ constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
+ super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+ per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+ constrain_weights=constrain_weights,
+ power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+ assert model_surgery_quantize == True, f'{self.__class__.__name__} does not support model_surgery_quantize=False. please use a qat or calibrated module.'
+ self.eval()
+
+
+ def train(self, mode=True):
+ assert mode == False, 'QuantTestModule cannot be used in train mode'
+ super().train(mode)
+
+
+########################################################################
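A usage sketch for the new test module (float_model and the checkpoint path are placeholders; depending on how the checkpoint was saved, xnn.utils.load_weights may be the more suitable loader):

    import torch

    model = QuantTestModule(float_model, dummy_input=torch.rand(1, 3, 224, 224))
    model.load_state_dict(torch.load('./qat_checkpoint.pth'))
    # the module forces eval mode; calling model.train(True) trips the assert above
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))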
from .quant_base_module import *
from .quant_utils import *
-class QuantTestModule(QuantBaseModule):
- def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
- range_calibration_online=False, model_surgery_quantize=True):
+class QuantEstimateModule(QuantBaseModule):
+ '''
+ QuantEstimateModule can be used to estimate the quantization accuracy of a float model
+ that has not gone through QAT or Calibration. However, this is an approximate method.
+ '''
+ def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+ histogram_range=True, range_calibration_online=False, model_surgery_quantize=True,
+ power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
- constrain_weights=False, model_surgery_quantize=model_surgery_quantize)
- # use power2_weights for now
- self.power2_weights = True
+ constrain_weights=False, model_surgery_quantize=model_surgery_quantize,
+ power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+ assert False, 'recommend to use QuantTestModule instead'
# whether to do online adjustment of calibration using previous frame range
self.range_calibration_online = range_calibration_online
# number of offline calibration iters. during offline calibration, current frame range is used
def replace_func(op):
for name, m in op._modules.items():
- if isinstance(m, layers.NoAct):
+ if isinstance(m, layers.QAct):
new_m = layers.PAct2(signed=None)
else:
new_m = None
def _forward_quantize_hook(self, op, *inputs_orig):
- inputs = utils.squeeze_list(inputs_orig)
+ inputs = utils.squeeze_list2(inputs_orig)
self.start_node(op)
self.start_quantize(op)
if qparams.quantize_w and weight is not None:
qparams.qrange_w = Dict()
- self.quantize_weights(module, weight, qparams.qrange_w)
+ self.quantize_weights_tensor(module, weight, qparams.qrange_w)
else:
qparams.qrange_w = None
if qparams.quantize_b and bias is not None:
qparams.qparams_b = Dict()
- self.quantize_bias(module, bias, qparams.qparams_b)
+ self.quantize_bias_tensor(module, bias, qparams.qparams_b)
else:
qparams.qparams_b = None
for inp in inputs:
inp.scale = inp.scale if hasattr(inp,'scale') else self.current_scale
- qrange_cur = self.quantize_inputs(module, inputs, outputs, qparams_prev, qparams)
+ qrange_cur = self.quantize_input_tensors(module, inputs, outputs, qparams_prev, qparams)
# create the current scale in process_inputs instead of process_outputs.
# otherwise exit condition for aggregate modules (eg. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
- # all the inputs scales are assumed to be aligned at this point (see align_inputs)
- # any module that needs special handling needs to be considered in quantize_inputs / align_inputs.
+ # all the inputs scales are assumed to be aligned at this point (see align_input_tensors)
+ # any module that needs special handling needs to be considered in quantize_input_tensors / align_input_tensors.
has_weight_scale = (hasattr(module,'weight') and (module.weight is not None) and hasattr(module.weight,'scale'))
if has_weight_scale:
is_dw = utils.is_dwconv(module)
for idx, opt in enumerate(output):
opt.scale = self.current_scale
- qrange_cur = self.quantize_outputs(module, inputs, output, qparams_prev, qparams)
+ qrange_cur = self.quantize_output_tensors(module, inputs, output, qparams_prev, qparams)
self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
self.current_scale = output[0].scale
def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
- update_range = (is_calibration or self.range_calibration_online)
- if update_range:
+ update_activation_range = (is_calibration or self.range_calibration_online)
+ if update_activation_range:
# in the case of fixed range module, we do not expand the ranges
fixed_range_module = utils.is_fixed_range(module)
if fixed_range_module:
return bitwidth_activations
- def quantize_weights(self, module, tensor_in, qrange):
+ def quantize_weights_tensor(self, module, tensor_in, qrange):
self.apply_constrain_weights(module)
bitwidth_weights = self.get_bitwidth_weights(module)
for chan in range(tensor_in.shape[0]):
# Range
mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
- tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weights)
+ tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weight_range)
qrange.min.append(mn)
qrange.max.append(mx)
# Quantize
else:
# Range
mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
- tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weights)
+ tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weight_range)
qrange.min = mn
qrange.max = mx
# Quantize
tensor_in.scale = 1.0
- def quantize_bias(self, module, tensor_in, qparams):
+ def quantize_bias_tensor(self, module, tensor_in, qparams):
quant_for_bias = True
if quant_for_bias:
bitwidth_weights = self.get_bitwidth_weights(module)
bitwidth_bias = bitwidth_weights
mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
- tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weights)
+ tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weight_range)
# --
tensor = symmetric_round_tensor(tensor_in * tensor_scale)
tensor_in.scale = 1.0
- def quantize_inputs(self, module, input, output, qparams_prev, qparams):
+ def quantize_input_tensors(self, module, input, output, qparams_prev, qparams):
qrange_cur = []
use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
for idx, inp in enumerate(input):
return qrange_cur
- def quantize_outputs(self, module, input, output, qparams_prev, qparams):
+ def quantize_output_tensors(self, module, input, output, qparams_prev, qparams):
qrange_cur = []
use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
for idx, opt in enumerate(output):
def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
- tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weights)
+ tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weight_range)
# print("mn : mx {} {}".format(mn, mx))
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
index e7601cdc40a60c95981beb51079c38197a37124f..f0b4529f1235068fad816207fa914b1a89195288 100644
###########################################################
class QuantTrainModule(QuantBaseModule):
def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
- histogram_range=True, bias_calibration=False, constrain_weights=None):
+ histogram_range=True, bias_calibration=False, constrain_weights=None,
+ power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
- constrain_weights=constrain_weights, model_surgery_quantize=True)
- # range shrink - 0.0 indicates no shrink
- percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
- # set attributes to all modules - can control the behaviour from here
- utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
- per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
- percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
- update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
+ constrain_weights=constrain_weights, model_surgery_quantize=True,
+ power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
def forward(self, inputs):
# counters such as num_batches_tracked are used. update them.
padding_mode = m.padding_mode
new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
+ elif utils.is_linear(m):
+ bias = (m.bias is not None)
+ new_m = QuantTrainLinear(in_features=m.in_features, out_features=m.out_features, bias=bias)
elif utils.is_bn(m):
new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
track_running_stats=m.track_running_stats)
elif isinstance(m, layers.PAct2):
new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
- per_channel_q=self.per_channel_q)
- elif isinstance(m, layers.NoAct):
+ per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+ power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
+ elif isinstance(m, layers.QAct):
new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
- per_channel_q=self.per_channel_q)
+ per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+ power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
- per_channel_q=self.per_channel_q)
+ per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+ power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
else:
new_m = None
#
if new_m is not None:
+ copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w')
for attr in dir(m):
value = getattr(m,attr)
if isinstance(value,torch.Tensor) and value is not None:
getattr(new_m,attr).data.copy_(value.data)
- elif isinstance(value,torch.nn.Module) and value is not None:
+ elif isinstance(value,torch.nn.Module):
setattr(new_m, attr, getattr(m,attr))
- elif attr in ('weight','bias','eps','clips_act','clips_w'): # add more attribute name here
+ elif attr in copy_attr_list:
+ # copy attributes that need to be copied
setattr(new_m, attr, getattr(m, attr))
#
#
qparams = get_qparams()
qparams.inputs.append(x)
qparams.modules.append(self)
- y.qparams = qparams
+ if hasattr(x, 'clips_act'):
+ qparams.clips_input = x.clips_act
#
+ y.qparams = qparams
return y
#
qparams = get_qparams()
qparams.inputs.append(x)
qparams.modules.append(self)
- y.qparams = qparams
+ if hasattr(x, 'clips_act'):
+ qparams.clips_input = x.clips_act
#
+ y.qparams = qparams
return y
#
+###########################################################
+class QuantTrainLinear(torch.nn.Linear):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.quantize_enable = True
+ self.bitwidth_weights = None
+ self.bitwidth_activations = None
+ self.per_channel_q = False
+
+ def forward(self, x):
+ is_merged = is_merged_layer(x)
+ if is_merged:
+ warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
+ #
+
+ y = super().forward(x)
+
+ if not self.quantize_enable:
+ # if quantization is disabled - return
+ return y
+ #
+
+ qparams = get_qparams()
+ qparams.inputs.append(x)
+ qparams.modules.append(self)
+ if hasattr(x, 'clips_act'):
+ qparams.clips_input = x.clips_act
+ #
+ y.qparams = qparams
+ return y
+ #
+
+
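Like the conv wrappers, QuantTrainLinear computes the normal float operation and only tags its output with qparams; a following QuantTrainPAct2 then fake-quantizes the merged result. A shape-arbitrary sketch:

    import torch

    fc = QuantTrainLinear(in_features=128, out_features=10, bias=True)
    act = QuantTrainPAct2(signed=None, bitwidth_weights=8, bitwidth_activations=8)
    y = act(fc(torch.rand(4, 128)))  # act reads the y.qparams that fc attached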
###########################################################
class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
def __init__(self, *args, **kwargs):
qparams = get_qparams()
qparams.inputs = [x.qparams.inputs[0], x]
qparams.modules = [x.qparams.modules[0], self]
+ if hasattr(x.qparams, 'clips_input'):
+ qparams.clips_input = x.qparams.clips_input
+ #
y.qparams = qparams
#
###########################################################
# fake quantized PAct2 for training
class QuantTrainPAct2(layers.PAct2):
- def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
- super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
+ def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None,
+ per_channel_q=False, percentile_range_shrink=layers.PAct2.PACT2_RANGE_SHRINK, power2_weight_range=True, power2_activation_range=True):
+ super().__init__(inplace=inplace, signed=signed, clip_range=clip_range, percentile_range_shrink=percentile_range_shrink,
+ power2_activation_range=power2_activation_range)
self.bitwidth_weights = bitwidth_weights
self.bitwidth_activations = bitwidth_activations
self.per_channel_q = per_channel_q
+ self.power2_weight_range = power2_weight_range
+
# weight shrinking is done by clamping weights - set this factor to zero.
# this must be zero - as in pact we do not have the actual weight param, but just a temporary tensor
# so any clipping we do here is not stored in the weight params
self.range_shrink_weights = 0.0
self.round_dither = 0.0
- self.update_range = True
+ self.update_activation_range = True
self.quantize_enable = True
self.quantize_weights = True
self.quantize_bias = True
self.quantize_activations = True
+ self.constrain_bias = None
self.constrain_weights = True
self.bias_calibration = False
- # save quantized weight/bias once in a while into the params - not needed
- self.params_save_frequency = None #(10 if self.bias_calibration else None)
+ # do joint quantization only after the activation range has stabilized reasonably.
+ self.constrain_bias_start_iter = 75
# set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
# For a comparison of STE and ABE, read:
def forward(self, x):
assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
'bitwidth_weights and bitwidth_activations must not be None'
-
# the pact range update happens here - but range clipping depends on quantize_enable
- y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
+ y = super().forward(x, update_activation_range=self.update_activation_range, enable=self.quantize_enable)
if not self.quantize_enable:
return y
conv, bn = None, None
# merge weight and bias (if possible) across layers
- if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
+ if len(qparams.modules) == 2 and utils.is_conv_deconv_linear(qparams.modules[-2]) and isinstance(
qparams.modules[-1], torch.nn.BatchNorm2d):
conv = qparams.modules[-2]
bn = qparams.modules[-1]
- elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
+ elif len(qparams.modules) == 1 and utils.is_conv_deconv_linear(qparams.modules[-1]):
conv = qparams.modules[-1]
elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - please inspect the model carefully'
else:
assert False, f'QuantTrainPAct2: both conv & bn layers cannot be None in a merged scenario - please inspect the model carefully'
#
- conv, weight, bias = self.merge_quantize_weights(conv, bn)
+ conv, weight, bias = self.merge_quantize_weights(qparams, conv, bn)
else:
conv, weight, bias = None, None, None
#
elif is_merged and utils.is_deconv(conv):
xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
dilation=conv.dilation, groups=conv.groups)
+ elif is_merged and utils.is_linear(conv):
+ xq = torch.nn.functional.linear(xorg, weight, bias)
else:
xq = x
#
# currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
# see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
# yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
- yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
+ yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2_activation_range, 'round_up')
else:
- yq = super().forward(xq, update_range=False, enable=True)
+ yq = super().forward(xq, update_activation_range=False, enable=True)
#
if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
#
+ # pass on the clips to be used in the next quantization
+ y.clips_act = self.get_clips_act()
return y
#
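For readers unfamiliar with the fake-quantization used throughout forward(), here is a minimal sketch of a quantize-dequantize with a straight-through gradient. This is an illustration only; the repository's quantize_dequantize_g additionally supports power-of-2 ranges and the 'round_up'/'round_sym' rounding modes used above, and may handle gradients differently.

import torch

def quantize_dequantize_ste(x, scale, width_min, width_max):
    # scale into the integer grid, round, clamp to the representable range,
    # then scale back to floating point
    xq = torch.clamp(torch.round(x * scale), width_min, width_max) / scale
    # straight-through estimator: forward uses xq, backward sees identity
    return x + (xq - x).detach()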
return quant_utils.constrain_weight(merged_weight)
- def merge_quantize_weights(self, conv, bn):
- # store the quantized weights and biases infrequently - otherwise learning will be poor
- # since this may not be done at the end of the epoch, there can be a slight mismatch in validation accuracy
- first_training_iter = self.training and (self.num_batches_tracked == 0)
- is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
+ def merge_quantize_weights(self, qparams, conv, bn):
+ num_batches_tracked = int(self.num_batches_tracked)
+ is_constrain_weights_iter = self.training and (num_batches_tracked == 0)
+ is_store_weights_iter = self.training and (num_batches_tracked == 0)
+ is_constrain_bias_iter = self.training and (num_batches_tracked>=self.constrain_bias_start_iter)
+ is_store_bias_iter = self.training and (num_batches_tracked==self.constrain_bias_start_iter)
# merge weight and bias (if possible) across layers
if conv is not None and bn is not None:
# quantize weight and bias
if (conv is not None):
if (self.quantize_enable and self.quantize_weights):
- if self.constrain_weights and first_training_iter:
+ if self.constrain_weights and is_constrain_weights_iter:
with torch.no_grad():
# clamp merged weights, invert the bn and copy to conv weight
constrained_weight = self.apply_constrain_weights(merged_weight.data)
#
width_min, width_max = self.get_widths_w()
# merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
- merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
+ merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2_weight_range, 'round_sym')
#
if (self.quantize_enable and self.quantize_bias):
bias_width_min, bias_width_max = self.get_widths_bias()
bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
- # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
- merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
+ power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
+ merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+ #
+
+ # in some cases, bias quantization has additional restrictions - for example, when
+ # the bias that is added to the accumulator is limited to 16 bits.
+ # the scale factor used for the bias is then the product of the scale factors of the weight and the input
+ if self.quantize_enable and self.constrain_bias and is_constrain_bias_iter:
+ clips_scale_joint = self.get_clips_scale_joint(qparams, merged_weight, merged_bias)
+ if clips_scale_joint is not None:
+ bias_width_min, bias_width_max = self.get_widths_joint()
+ bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = clips_scale_joint
+ power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
+ merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+ #
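To make the comment above concrete: if the input activation is quantized with scale scale_input and the weight with scale scale_w, the convolution accumulator - and hence the bias added to it - lives on the grid with scale scale_w * scale_input. A hedged sketch of that arithmetic follows; the names and the 8-bit defaults are illustrative, not the repository's API.

def joint_bias_scale(clip_max_input, clip_max_w, bitwidth=8):
    # per-tensor scales (width_max / clip_max), as in the get_clips_scale_* helpers
    scale_input = (2.0 ** bitwidth) / clip_max_input   # unsigned activation: width_max = 2^8
    scale_w = (2.0 ** (bitwidth - 1)) / clip_max_w     # signed weight: width_max = 2^7
    bias_scale = scale_w * scale_input                 # bias shares the accumulator grid
    bias_width_max = 2.0 ** (2 * bitwidth - 1)         # e.g. a 16-bit signed accumulator limit
    return bias_scale, bias_width_max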
#
# invert the bn operation and store weights/bias
- if first_training_iter or (self.training and is_store_weight_bias_iter):
- with torch.no_grad():
- if self.quantize_enable and self.quantize_weights:
- conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
- #
- if self.quantize_enable and self.quantize_bias:
- if conv.bias is not None:
- if bn is not None:
- conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
- conv.bias.data.copy_(conv_bias.data)
- else:
- conv.bias.data.copy_(merged_bias.data)
- #
- elif bn is not None and bn.bias is not None:
- bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
- bn.bias.data.copy_(bn_bias.data)
- #
+ if self.quantize_enable and self.quantize_weights and is_store_weights_iter:
+ conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
+ #
+ if self.quantize_enable and self.quantize_bias and is_store_bias_iter:
+ if conv.bias is not None:
+ if bn is not None:
+ conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
+ conv.bias.data.copy_(conv_bias.data)
+ else:
+ conv.bias.data.copy_(merged_bias.data)
#
+ elif bn is not None and bn.bias is not None:
+ bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
+ bn.bias.data.copy_(bn_bias.data)
#
#
#
return conv, merged_weight, merged_bias
+ ###########################################################
+ def get_widths_w(self):
+ # weights
+ bw = (self.bitwidth_weights - 1)
+ width_max = np.power(2.0, bw)
+ width_min = -width_max
+ # return
+ return (width_min, width_max)
+
+
def get_clips_w(self, tensor):
# find the clip values
w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
clip_max = torch.clamp(clip_max, min=self.eps)
- # in range learning mode + training - this power2 is taken care in the quantize function
- use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
+ # in range learning mode + training, this power2_weight_range is taken care of in the quantize function
+ use_power2 = (self.power2_weight_range and (not (self.PACT2_RANGE_LEARN and self.training)))
clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
clip_min2 = -clip_max2
return (clip_min2, clip_max2)
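ceil2_g above rounds the clip value up to the nearest power of two, which makes the resulting scale a pure bit-shift in fixed-point hardware. Assuming that is all it does (gradient handling aside), a plain-torch equivalent would be:

import torch

def ceil2(clip_max):
    # round a positive tensor up to the nearest power of two, e.g. 1.3 -> 2.0, 5.0 -> 8.0
    return torch.pow(2.0, torch.ceil(torch.log2(clip_max)))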
- # bias uses the same kind of clips
- get_clips_bias = get_clips_w
-
def get_clips_scale_w(self, weight):
- # convert to scale
clip_min, clip_max = self.get_clips_w(weight)
width_min, width_max = self.get_widths_w()
scale2 = (width_max / clip_max)
scale_inv2 = scale2.pow(-1.0)
return (clip_min, clip_max, scale2, scale_inv2)
+ ###########################################################
+ def get_widths_act(self):
+ if self.signed is None:
+ clip_min, clip_max = self.get_clips_act()
+ signed = (clip_min < 0.0)
+ else:
+ signed = self.signed
+ #
+ bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
+ width_max = np.power(2.0, bw)
+ width_min = -width_max if signed else 0.0
+ return width_min, width_max
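As a quick check of the width arithmetic in get_widths_w() and get_widths_act(): with 8-bit signed values the range is [-128, 128), while 8-bit unsigned activations use the full [0, 256). A standalone version of the same computation:

import numpy as np

def widths(bitwidth, signed):
    # signed tensors spend one bit on the sign
    bw = (bitwidth - 1) if signed else bitwidth
    width_max = np.power(2.0, bw)
    width_min = -width_max if signed else 0.0
    return (width_min, width_max)

assert widths(8, True) == (-128.0, 128.0)
assert widths(8, False) == (0.0, 256.0)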
- # in reality, bias quantization will also depend on the activation scale
- # this is not perfect - just a quick and dirty quantization for bias
- def get_clips_scale_bias(self, bias):
- # convert to scale
- clip_min, clip_max = self.get_clips_bias(bias)
- width_min, width_max = self.get_widths_bias()
- scale2 = (width_max / clip_max)
+
+ def get_clips_scale_act(self):
+ clip_min, clip_max = self.get_clips_act()
+ width_min, width_max = self.get_widths_act()
+ scale2 = width_max / clip_max
scale2 = torch.clamp(scale2, min=self.eps)
scale_inv2 = scale2.pow(-1.0)
return (clip_min, clip_max, scale2, scale_inv2)
- def get_widths_w(self):
- # weights
- bw = (self.bitwidth_weights - 1)
- width_max = np.power(2.0, bw)
- width_min = -width_max
- # return
- return (width_min, width_max)
+ ###########################################################
+ # bias uses the same kind of widths
+ get_widths_bias = get_widths_w
- def get_widths_bias(self):
- # bias
- bitwidth_bias = (2*self.bitwidth_activations)
- bias_width_max = np.power(2.0, bitwidth_bias-1)
- bias_width_min = -bias_width_max
- # return
- return (bias_width_min, bias_width_max)
+ # bias uses the same kind of clips
+ get_clips_bias = get_clips_w
- # activation utility functions
- def get_clips_scale_act(self):
- # convert to scale
- clip_min, clip_max = self.get_clips_act()
- width_min, width_max = self.get_widths_act()
- scale2 = width_max / clip_max
+ def get_clips_scale_bias(self, bias):
+ clip_min, clip_max = self.get_clips_bias(bias)
+ width_min, width_max = self.get_widths_bias()
+ scale2 = (width_max / clip_max)
scale2 = torch.clamp(scale2, min=self.eps)
scale_inv2 = scale2.pow(-1.0)
return (clip_min, clip_max, scale2, scale_inv2)
- def get_widths_act(self):
- if self.signed is None:
- clip_min, clip_max = self.get_clips_act()
- signed = (clip_min < 0.0)
+ ###########################################################
+ def get_widths_joint(self):
+ bw = (2*self.bitwidth_weights - 1)
+ width_max = np.power(2.0, bw)
+ width_min = -width_max
+ return (width_min, width_max)
+
+
+ def get_clips_input(self, qparams):
+ if hasattr(qparams, 'clips_input'):
+ return qparams.clips_input
else:
- signed = self.signed
+ return None
#
+
+ def get_widths_input(self, clip_min, clip_max):
+ signed = (clip_min < 0.0)
bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
width_max = np.power(2.0, bw)
width_min = -width_max if signed else 0.0
return width_min, width_max
+
+ def get_clips_scale_input(self, qparams):
+ clips_input = self.get_clips_input(qparams)
+ if clips_input is not None:
+ clip_min, clip_max = clips_input
+ width_min, width_max = self.get_widths_input(clip_min, clip_max)
+ scale2 = (width_max / clip_max)
+ scale2 = torch.clamp(scale2, min=self.eps)
+ scale_inv2 = scale2.pow(-1.0)
+ return (clip_min, clip_max, scale2, scale_inv2)
+ else:
+ return None
+ #
+
+
+ def get_clips_scale_joint(self, qparams, weights, bias):
+ clips_scale_input = self.get_clips_scale_input(qparams)
+ if clips_scale_input is not None:
+ clip_min_input, clip_max_input, scale2_input, scale_inv2_input = clips_scale_input
+ clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights)
+ clip_min_bias, clip_max_bias, scale2_bias, scale_inv2_bias = self.get_clips_scale_bias(bias)
+ return (clip_min_bias, clip_max_bias, scale2_w*scale2_input, scale_inv2_w*scale_inv2_input)
+ else:
+ return None
+ #
\ No newline at end of file
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py b/modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
index eab5d3f043bb1c2703b406323ed880be1a5c3102..3bc9743abb5204804cf0078bc3977b1fbfc52904 100644 (file)
######################################################
def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
+ download_root = './' if (download_root is None) else download_root
if pretrained is None or pretrained is False:
print_utils.print_yellow(f'=> weights could not be loaded. pretrained data given is {pretrained}')
return model
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
index c12a26fdf73109f70d14f55452aa8f797a43ca16..9fcd6e9198eae2c3079e6de22c7f38abc3f76d20 100644 (file)
def is_activation(module):
is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
- layers.NoAct, layers.PAct2))
+ layers.PAct2, layers.QAct, layers.NoQAct))
return is_act
def is_pact2(module):
def is_conv_deconv(module):
return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
+def is_conv_deconv_linear(module):
+ return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))
+
def is_linear(module):
return isinstance(module, torch.nn.Linear)
def squeeze_list(inputs):
+ return inputs[0] if (is_list(inputs) and len(inputs)==1) else inputs
+
+
+def squeeze_list2(inputs):
return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs
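The two helpers differ only in when they unwrap: squeeze_list unwraps any single-element list, while squeeze_list2 unwraps only when that single element is itself a list (assuming is_list() accepts plain Python lists):

squeeze_list([5])         # -> 5
squeeze_list2([5])        # -> [5]  (element is not a list, left unchanged)
squeeze_list([[1, 2]])    # -> [1, 2]
squeeze_list2([[1, 2]])   # -> [1, 2]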
diff --git a/run_quantization.sh b/run_quantization.sh
index 3f80ed99cf69ffbbdd738e8ae6b84155e2e0d0f8..461dbc4bb0147bc9c5b506770bb411b351b9ee32 100755 (executable)
--- a/run_quantization.sh
+++ b/run_quantization.sh
-# Quantization
-
-## =====================================================================================
-## Quantization Aware Training
-## =====================================================================================
-#
-#### Image Classification - Quantization Aware Training - MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+## Quantization
#
-#
-#### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
-#
-#
-#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
-#
-#
-#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
-
-
-
## =====================================================================================
## Post Training Calibration & Quantization - this is fast, but may not always yield best quantized accuracy (not recommended)
## =====================================================================================
#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
#--batch_size 6 --quantize True --epochs 1 --evaluate_start False
-
-
-
+#
+#
+## =====================================================================================
+## Quantization Aware Training
+## =====================================================================================
+#
+#### Image Classification - Quantization Aware Training - MobileNetV2
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
+#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#
+#
+#### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
+#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#
+#
+#
## =====================================================================================
-## Acuracy Evaluation with Post Training Quantization - cannot save quantized model - only accuracy evaluation
+## Accuracy Evaluation with Post Training Quantization - this is not supported anymore.
+## Either Calibration or QAT has to be performed first to get correct accuracy.
+## Please use one of the sections above.
## =====================================================================================
-
+#
#### Image Classification - Accuracy Estimation with Post Training Quantization - MobileNetV2
#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
#--batch_size 64 --quantize True
-
+#
#### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet50
#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
#--batch_size 64 --quantize True
-
+#
#### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
#--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
#--batch_size 64 --quantize True
-
+#
#### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
#--batch_size 64 --quantize True
-
+#
#### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
#--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth' \
#--batch_size 1 --quantize True
-
+#
#### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \