From: Manu Mathew Date: Tue, 19 May 2020 17:13:48 +0000 (+0530) Subject: support DataParallel for QuantTrainModule X-Git-Url: https://git.ti.com/gitweb?p=jacinto-ai%2Fpytorch-jacinto-ai-devkit.git;a=commitdiff_plain;h=8273351c0f889420acd8e08627ebb2c1f1ddb687 support DataParallel for QuantTrainModule quantization docs update release commit --- diff --git a/docs/Calibration.md b/docs/Calibration.md index 385c0c7..0d1098c 100644 --- a/docs/Calibration.md +++ b/docs/Calibration.md @@ -90,3 +90,6 @@ python ./scripts/train_segmentation_main.py --phase calibration --dataset_name c --batch_size 12 --quantize True --epochs 1 ``` +## Guidelines, Implementation Notes, Limitations & Recommendations +- Please refer to the section on Quantization Aware Training, as the same guidelines, recomendations & limitations apply to QuantCalibrateModule.
+- An additional limitation is that multi gpu processing with DataParallel / DistributedDataParallel is not supported for QuantCalibrateModule (also for QuantTestModule). In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that. (However multi gpu support with DataParallel works for QuantTrainModule - more details of this in the QAT section).
diff --git a/docs/Quantization.md b/docs/Quantization.md index 803b2a6..5ae5531 100644 --- a/docs/Quantization.md +++ b/docs/Quantization.md @@ -29,9 +29,8 @@ To get best accuracy at the quantization stage, it is important that the model i - **The same module should not be re-used multiple times within the module** in order that the feature map range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](../modules/pytorch_jacinto_ai/vision/models/resnet.py).
- If you have done QAT and is getting poor accuracy either in the Python code or during inference in the platform, please inspect your model carefully to see if the above recommendations have been followed - some of these can be easily missed by oversight - and can result in painful debugging that could have been avoided.
- However, if a function does not change the range of feature map, it is not critical to use it in Module form. An example of this is torch.nn.functional.interpolate
-- **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantTrainModule/QuantCalibrateModule/QuantTestModule. We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantTrainModule/QuantCalibrateModule/QuantTestModule.
-- If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantTrainModule/QuantCalibrateModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original floating point training. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.
-- If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again. +- **Multi-GPU training/calibration/validation with DataParallel is supported with our QAT module** QuantTrainModule. This takes care of a major concern that was earlier there in doing QAT with QuantTrainModule. (However it is not supported for QuantCalibrateModule/QuantTestModule - these calibration/test phases take much less time - so hopefully this is not a big issue. In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule, but we do that for QuantTrainModule).
+- If your training/calibration crashes because of insufficient GPU memory, reduce the batch size and try again. - This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/resize_blocks.py) to export a clean resize/interpolate/upsamle graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form. - If you are using TIDL to infer a model trained using QAT (or calibratied using PTQ) tools provided in this repository, please set the following in the import config file for best accuracy: **quantizationStyle = 3** to use power of 2 quantization. **foldPreBnConv2D = 0** to avoid a slight accuracy degradation due to incorrect folding of BatchNormalization that comes before Convolution (input mean/scale is implemented in TIDL as a PreBN - so this affects most networks). diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py index 89f08b6..609b134 100644 --- a/modules/pytorch_jacinto_ai/engine/train_classification.py +++ b/modules/pytorch_jacinto_ai/engine/train_classification.py @@ -282,12 +282,13 @@ def main(args): exit() ################################################# - # multi gpu mode is not working for quantized model - if args.parallel_model and (not args.quantize): + # DataParallel does not work for QuantCalibrateModule or QuantTestModule + if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))): if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model) else: model = torch.nn.DataParallel(model) + # # ################################################# diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py index d838078..6e9e623 100644 --- a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py +++ b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py @@ -401,9 +401,8 @@ def main(args): exit() ################################################# - # multi gpu mode does not work for calibration/training for quantization - # so use it only when args.quantize is False - if args.parallel_model and ((not args.quantize)): + # DataParallel does not work for QuantCalibrateModule or QuantTestModule + if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))): model = torch.nn.DataParallel(model) ################################################# diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py index 39c6482..3f40d38 100644 --- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py +++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py @@ -7,13 +7,10 @@ from ..utils import AttrDict as Dict from .hooked_module import * class QuantGraphModule(HookedModule): - # instance member states are not retained across forward calls when using DataParallel - # as a workaround, use class member variables instead, so that these can be retained - states = Dict() def __init__(self, module): super().__init__() self.module = module - self.init_states() + self.init_qstate() self.register_buffer('num_batches_tracked', torch.tensor(-1.0)) self.register_buffer('iter_in_epoch', torch.tensor(-1.0)) self.register_buffer('epoch', torch.tensor(-1.0)) @@ -28,62 +25,28 @@ class QuantGraphModule(HookedModule): # # - # create the state object required to keep some quantization parameters that need to be preserved - # a cuda() has been called on the module - copy the states from that was created for cpu - def get_state(self): - states = self.get_states() - module_device = self.module_device(self.module) - if module_device not in states: - module_device_src = None - for key, value in states.items(): - if key.type == 'cpu': - module_device_src = key - # - if module_device_src is not None and module_device_src in states: - states[module_device] = copy.deepcopy(states[module_device_src]) - else: - states[module_device] = Dict() + def init_qstate(self): + if not hasattr(self, '__qstate__'): + self.__qstate__ = Dict() + # + if 'qparams' not in self.get_qstate(): + self.get_qstate().qparams = Dict() + # + if 'qparams_prev' not in self.get_qstate(): + self.get_qstate().qparams_prev = Dict() + # + if 'analyzed_graph' not in self.get_qstate(): + self.get_qstate().analyzed_graph = False # - return states[module_device] - - - def get_states(self): - return __class__.states - - def clear_states(self): - __class__.states = Dict() - # these entries will prevent this modules from being used with DataParallel - cleanup - def cleanup_states(self): - assert self.get_state().analyzed_graph == True, 'graph must be analyzed before cleanup_states()' - with torch.no_grad(): - for module_hash, qparams in self.get_state().qparams.items(): - if hasattr(qparams, 'previous_node'): - del qparams.previous_node - # - if hasattr(qparams, 'previous_module'): - del qparams.previous_module - # - if hasattr(qparams, 'next_node'): - del qparams.next_node - # - if hasattr(qparams, 'next_module'): - del qparams.next_module - # - # - # + def clear_qstate(self): + self.__qstate__ = Dict() + self.init_qstate() - # data parallel does not initialize the replicas correctly, explicitly initialize them. - # there is no use in doing this in __init__. it has to be done in forward even if it is called in __init__. - def init_states(self): - if 'qparams' not in self.get_state().keys(): - self.get_state().qparams = Dict() - if 'qparams_prev' not in self.get_state().keys(): - self.get_state().qparams_prev = Dict() - if 'analyzed_graph' not in self.get_state().keys(): - self.get_state().analyzed_graph = False + def get_qstate(self): + return self.__qstate__ def forward(self, inputs): @@ -102,24 +65,24 @@ class QuantGraphModule(HookedModule): # force_update is used to increment inte counters even in non training # used for validation in QuantTestModule - def analyze_graph(self, inputs, force_update=False, merge_weights=False, cleanup_states=False): + def analyze_graph(self, inputs, force_update=False, merge_weights=False, clear_qstate=False): with torch.no_grad(): - self.init_states() + self.init_qstate() self.update_counters(force_update=force_update) - if (self.get_state().analyzed_graph == False): + if (self.get_qstate().analyzed_graph == False): # forward and analyze self.forward_analyze_modules(inputs) # analyze the connections self.analyze_connections() - self.get_state().analyzed_graph = True + self.get_qstate().analyzed_graph = True # merge weights so that weight quantization can be done if merge_weights: self.merge_weights() # - if cleanup_states: - self.cleanup_states() + if clear_qstate: + self.clear_qstate() # # # @@ -127,18 +90,18 @@ class QuantGraphModule(HookedModule): def model_surgery_quantize(self, dummy_input): # lear the sates - just to be sure - self.clear_states() + self.clear_qstate() # analyze self.analyze_graph(dummy_input) # insert NoAct wherever range clipping needs to be done self.model_surgery_activations() # since we might have added new activations, clear the sates as they may not be valid - self.clear_states() + self.clear_qstate() # need to call analyze_graph in the derived class # def model_surgery_activations(self): - for module_hash, qparams in self.get_state().qparams.items(): + for module_hash, qparams in self.get_qstate().qparams.items(): module = self.get_module(module_hash) if isinstance(module, layers.PAct2): pass @@ -170,14 +133,6 @@ class QuantGraphModule(HookedModule): # - def start_call(self): - self.call_count = Dict() - - - def finish_call(self): - self.call_count = None - - def train(self, mode=True): self.iter_in_epoch.fill_(-1.0) super().train(mode) @@ -185,6 +140,15 @@ class QuantGraphModule(HookedModule): ################################################################ def forward_analyze_modules(self, inputs): + ''' + analyze modules needs a call hook - the call hook does not work with DataParallel. + So, do the analysis on a copy. + ''' + self_copy = copy.deepcopy(self) + self_copy._forward_analyze_modules_impl(inputs) + self.get_qstate().qparams = self_copy.get_qstate().qparams + + def _forward_analyze_modules_impl(self, inputs): self.start_call() self.add_call_hook(self.module, self._analyze_modules_op) output = self.module(inputs) @@ -205,29 +169,29 @@ class QuantGraphModule(HookedModule): inputs = self.format_tensors(inputs) module_hash = self.module_hash(module) - if module_hash not in list(self.get_state().qparams.keys()): - self.get_state().qparams[module_hash] = Dict() - self.get_state().qparams[module_hash].qrange_w = None - self.get_state().qparams[module_hash].qrange_b = None - self.get_state().qparams[module_hash].qrange_in = [] - self.get_state().qparams[module_hash].qrange_out = [] - self.get_state().qparams[module_hash].is_input = (self.module is module) - self.get_state().qparams[module_hash].previous_node = [] - self.get_state().qparams[module_hash].next_node = [] - self.get_state().qparams[module_hash].current_node = module_hash - - current_node = self.get_state().qparams[module_hash].current_node + if module_hash not in list(self.get_qstate().qparams.keys()): + self.get_qstate().qparams[module_hash] = Dict() + self.get_qstate().qparams[module_hash].qrange_w = None + self.get_qstate().qparams[module_hash].qrange_b = None + self.get_qstate().qparams[module_hash].qrange_in = [] + self.get_qstate().qparams[module_hash].qrange_out = [] + self.get_qstate().qparams[module_hash].is_input = (self.module is module) + self.get_qstate().qparams[module_hash].previous_node = [] + self.get_qstate().qparams[module_hash].next_node = [] + self.get_qstate().qparams[module_hash].current_node = module_hash + + current_node = self.get_qstate().qparams[module_hash].current_node for inp in inputs: if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'): prev_module_hash = inp.qparams.last_node prev_module = self.get_module(prev_module_hash) - previous_node = self.get_state().qparams[module_hash].previous_node - next_node = self.get_state().qparams[prev_module_hash].next_node + previous_node = self.get_qstate().qparams[module_hash].previous_node + next_node = self.get_qstate().qparams[prev_module_hash].next_node if str(inp.qparams.last_node) not in [str(p) for p in previous_node]: - self.get_state().qparams[module_hash].previous_node += [inp.qparams.last_node] + self.get_qstate().qparams[module_hash].previous_node += [inp.qparams.last_node] if str(current_node) not in [str(n) for n in next_node]: - self.get_state().qparams[prev_module_hash].next_node += [current_node] + self.get_qstate().qparams[prev_module_hash].next_node += [current_node] if outputs is not None: outputs = self.format_tensors(outputs) @@ -245,13 +209,13 @@ class QuantGraphModule(HookedModule): ################################################################ def analyze_connections(self): prediction_module = None - for module_hash, qparams in self.get_state().qparams.items(): + for module_hash, qparams in self.get_qstate().qparams.items(): module = self.get_module(module_hash) if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module): prediction_module = module # # - for module_hash, qparams in self.get_state().qparams.items(): + for module_hash, qparams in self.get_qstate().qparams.items(): module = self.get_module(module_hash) is_prediction = (prediction_module is module) self._analyse_connections_op(module_hash, module, qparams, is_prediction) @@ -310,9 +274,9 @@ class QuantGraphModule(HookedModule): ################################################################ def merge_weights(self, make_backup=False): - assert self.get_state().analyzed_graph == True, 'graph must be analyzed before merge_weights()' + assert self.get_qstate().analyzed_graph == True, 'graph must be analyzed before merge_weights()' with torch.no_grad(): - for module_hash, qparams in self.get_state().qparams.items(): + for module_hash, qparams in self.get_qstate().qparams.items(): module = self.get_module(module_hash) self._merge_weight_op(module_hash, module, qparams, make_backup) # @@ -367,52 +331,62 @@ class QuantGraphModule(HookedModule): warnings.warn('problem detected in conv+bn layer pair: either one of conv.bias or bn.affine is required for successfull calibration - preferably bn.affine') # # - return ################################################################ def get_qparams(self, module): module_hash = self.module_hash(module) - return self.get_state().qparams[module_hash] + return self.get_qstate().qparams[module_hash] + def get_qparams_prev(self, module): module_hash = self.module_hash(module) - return self.get_state().qparams_prev[module_hash] if self.get_state().qparams_prev else None + return self.get_qstate().qparams_prev[module_hash] if self.get_qstate().qparams_prev else None + + + def start_call(self): + self.call_count = Dict() + + + def finish_call(self): + self.call_count = None + def start_node(self, module): module_name = self.module_name(module) if module_name not in list(self.call_count.keys()): self.call_count[module_name] = 0 - + # return + def finish_node(self, module, inputs, outputs): module_name = self.module_name(module) self.call_count[module_name] = self.call_count[module_name] + 1 return + def module_hash(self, module): + ''' + A module may be called multiple times in a model. This module has creates a unique name/hash for each call + using teh call_count. call_count needs tobe correct for this to work as expected. + call_count is kep up to date by using start_node() / finish_node() calls. + ''' module_name = self.module_name(module) module_hash = module_name + '-call:{}'.format(self.call_count[module_name]) return module_hash + def module_name(self, module): name = None for n, m in self.named_modules(): if m is module: name = n + # # return name - def module_device(self, module=None): - module = module if module is not None else self.module - try: - module_device = next(module.parameters()).device - except: - module_device = None - # - return module_device def get_module(self, module_hash): module_name = module_hash.split('-call:')[0] @@ -422,10 +396,12 @@ class QuantGraphModule(HookedModule): # return None + def is_last_conv(self, module): # implementation is not correct. disable it for the time being return False #(module is self.last_conv_linear_module) + def format_tensors(self, inputs): # make a list/tuple if inputs is not. if it is a double list, remove the extra one inputs = utils.squeeze_list(utils.make_list(inputs)) @@ -436,15 +412,16 @@ class QuantGraphModule(HookedModule): def copy_qparams(self, qparams, inputs): qparams_copy = Dict() - - # deep copy may not work in some cases, so do it conditionally for module_hash, qparam_entry in qparams.items(): qparams_copy[module_hash] = Dict() for key, value in qparam_entry.items(): + # deep copy may not work in some cases, so do it conditionally try: qparams_copy[module_hash][key] = copy.deepcopy(value) except Exception: qparams_copy[module_hash][key] = value - + # + # + # return qparams_copy diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py index 55f2301..64bce32 100644 --- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py +++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py @@ -66,13 +66,13 @@ class QuantTestModule(QuantBaseModule): self.apply(replace_func) # clear - self.clear_states() + self.clear_qstate() # def forward(self, inputs): # analyze - need to merge_weights - so call analyze_graph() instead of just update_counters() - self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True, cleanup_states=True) + self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True) # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0) # if batch_size != 1: @@ -84,7 +84,7 @@ class QuantTestModule(QuantBaseModule): # quantize outputs = self.forward_quantize(inputs) # start and new frame, copy the qparams for previous frame of inference - self.get_state().qparams_prev = self.copy_qparams(self.get_state().qparams, inputs) + self.get_qstate().qparams_prev = self.copy_qparams(self.get_qstate().qparams, inputs) # return return outputs # diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py index 7fe340e..e7601cd 100644 --- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py +++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py @@ -91,7 +91,7 @@ class QuantTrainModule(QuantBaseModule): self.apply(replace_func) # clear - self.clear_states() + self.clear_qstate() # diff --git a/scripts/train_classification_main.py b/scripts/train_classification_main.py index e7e01f1..19b364c 100755 --- a/scripts/train_classification_main.py +++ b/scripts/train_classification_main.py @@ -171,9 +171,6 @@ if 'training' in args.phase and (not args.quantize): args.quantize = True args.lr = 1e-5 args.epochs = 50 - # quantized training will use only one GPU in the engine - so reduce the batch_size - num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) - args.batch_size = args.batch_size//num_gpus train_classification.main(args) # @@ -182,6 +179,12 @@ if 'training' in args.phase and (not args.quantize): if 'training' in args.phase or 'calibration' in args.phase: save_path = train_classification.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth') + if 'training' in args.phase: + # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule. + # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it. + num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None + args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size + # args.phase = 'validation' args.quantize = True train_classification.main(args) diff --git a/scripts/train_motion_segmentation_main.py b/scripts/train_motion_segmentation_main.py index 2f13e15..1904225 100755 --- a/scripts/train_motion_segmentation_main.py +++ b/scripts/train_motion_segmentation_main.py @@ -151,9 +151,6 @@ if 'training' in args.phase and (not args.quantize): args.quantize = True args.lr = 1e-5 args.epochs = 50 - # quantized training will use only one GPU in the engine - so reduce the batch_size - num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None - args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size train_pixel2pixel.main(args) # @@ -162,6 +159,12 @@ if 'training' in args.phase and (not args.quantize): if 'training' in args.phase or 'calibration' in args.phase: save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth') + if 'training' in args.phase: + # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule. + # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it. + num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None + args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size + # args.phase = 'validation' args.quantize = True train_pixel2pixel.main(args) diff --git a/scripts/train_pixel2pixel_multitask_main.py b/scripts/train_pixel2pixel_multitask_main.py index a505500..802b5ff 100755 --- a/scripts/train_pixel2pixel_multitask_main.py +++ b/scripts/train_pixel2pixel_multitask_main.py @@ -163,9 +163,6 @@ if 'training' in args.phase and (not args.quantize): args.quantize = True args.lr = 1e-5 args.epochs = 50 - # quantized training will use only one GPU in the engine - so reduce the batch_size - num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None - args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size train_pixel2pixel.main(args) # @@ -174,6 +171,12 @@ if 'training' in args.phase and (not args.quantize): if 'training' in args.phase or 'calibration' in args.phase: save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') + if 'training' in args.phase: + # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule. + # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it. + num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None + args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size + # args.phase = 'validation' args.quantize = True train_pixel2pixel.main(args) diff --git a/scripts/train_segmentation_main.py b/scripts/train_segmentation_main.py index a87dd91..f85b5f0 100755 --- a/scripts/train_segmentation_main.py +++ b/scripts/train_segmentation_main.py @@ -166,9 +166,6 @@ if 'training' in args.phase and (not args.quantize): args.quantize = True args.lr = 1e-5 args.epochs = 50 - # quantized training will use only one GPU in the engine - so reduce the batch_size - num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None - args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size train_pixel2pixel.main(args) # @@ -177,6 +174,12 @@ if 'training' in args.phase and (not args.quantize): if 'training' in args.phase or 'calibration' in args.phase: save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth') + if 'training' in args.phase: + # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule. + # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it. + num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None + args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size + # args.phase = 'validation' args.quantize = True train_pixel2pixel.main(args)