[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_classification.py
diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py
index bc208a4f6e4d0d88b69ed9cd589b898713d2d7bf..9bfde8e6b7bcfc440167b5d69862096917671d0d 100644 (file)
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
import os
import shutil
import time
# Build the default configuration node that is returned to the caller.
args = xnn.utils.ConfigNode()
args.model_config = xnn.utils.ConfigNode()
args.dataset_config = xnn.utils.ConfigNode()
+ args.model_config.input_channels = 3 # num input channels
+ args.model_config.output_type = 'classification'
+ args.model_config.output_channels = None
+ args.model_config.strides = None # (2,2,2,2,2)
args.model_config.num_tiles_x = int(1)
args.model_config.num_tiles_y = int(1)
args.model_config.en_make_divisible_by8 = True
- args.model_config.input_channels = 3 # num input channels
+ args.model_config.enable_fp16 = False # FP16 half precision mode
args.input_channel_reverse = False # rgb to bgr
args.data_path = './data/datasets/ilsvrc' # path to dataset
args.model_name = 'mobilenetv2_tv_x1' # model architecture
args.model = None # if the model is created externally
args.dataset_name = 'imagenet_classification' # image folder classification
+ args.transforms = None # the transforms itself can be given from outside
args.save_path = None # checkpoints save path
args.phase = 'training' # training/calibration/validation
args.date = None # date to add to save path. if this is None, current date will be added.
- args.workers = 8 # number of data loading workers (default: 8)
+ args.workers = 12 # number of data loading workers (default: 12)
args.logger = None # logger stream to output into
- args.epochs = 90 # number of total epochs to run
- args.warmup_epochs = None # number of epochs to warm up by linearly increasing lr
+ args.epochs = 150 # number of total epochs to run: recommended 100 or 150
+ args.warmup_epochs = 5 # number of epochs to warm up by linearly increasing lr
+ args.warmup_factor = 1e-3 # max lr allowed for the first epoch during warmup (as a factor of initial lr)
args.epoch_size = 0 # fraction of training epoch to use each time. 0 indicates full
args.epoch_size_val = 0 # manual epoch size (will match dataset size if not specified)
args.start_epoch = 0 # manual epoch number to start
args.stop_epoch = None # manual epoch number to stop
- args.batch_size = 256 # mini_batch size (default: 256)
+ args.batch_size = 512 # mini_batch size (default: 512)
args.total_batch_size = None # accumulated batch size. total_batch_size = batch_size*iter_size
args.iter_size = 1 # iteration size. total_batch_size = batch_size*iter_size
args.lr_clips = None # use args.lr itself if it is None
args.lr_calib = 0.05 # lr for bias calibration
args.momentum = 0.9 # momentum
- args.weight_decay = 1e-4 # weight decay (default: 1e-4)
+ args.weight_decay = 4e-5 # weight decay (default: 4e-5)
args.bias_decay = None # bias decay (default: 0.0)
args.shuffle = True # shuffle or not
args.dist_url = 'tcp://224.66.41.62:23456' # url used to set up distributed training
args.dist_backend = 'gloo' # distributed backend
- args.optimizer = 'sgd' # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
- args.scheduler = 'step' # help='scheduler algorithms, choices=['step','poly','exponential', 'cosine']
+ args.optimizer = 'sgd' # optimizer algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
+ args.scheduler = 'cosine' # scheduler algorithms, choices=['step','poly','exponential', 'cosine']
args.milestones = (30, 60, 90) # epochs at which learning rate is divided
args.multistep_gamma = 0.1 # multi step gamma (default: 0.1)
args.polystep_power = 1.0 # poly step power (default: 1.0)
- args.step_size = 1, # step size for exp lr decay
+ args.step_size = 1 # step size for exp lr decay
args.beta = 0.999 # beta parameter for adam
args.pretrained = None # path to pre_trained model
args.img_resize = 256 # image resize
args.img_crop = 224 # image crop
- args.rand_scale = (0.08,1.0) # random scale range for training
+ args.rand_scale = (0.2,1.0) # random scale range for training
args.data_augument = 'inception' # data augmentation method, choices=['inception','resize','adaptive_resize']
args.count_flops = True # count flops and report
args.histogram_range = True # histogram range for calibration
args.bias_calibration = True # apply bias correction during quantized inference calibration
args.per_channel_q = False # apply separate quantization factor for each channel in depthwise or not
+ args.constrain_bias = None # constrain bias according to the constraints of convolution engine
args.freeze_bn = False # freeze the statistics of bn
args.save_mod_files = False # saves modified files after last commit. Also stores commit id.
- args.opset_version = 9 # onnx opset_version
+ args.opset_version = 11 # onnx opset_version
return args
if 'training' in args.phase:
model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
- bitwidth_activations=args.bitwidth_activations,
+ bitwidth_activations=args.bitwidth_activations, constrain_bias=args.constrain_bias,
dummy_input=dummy_input)
elif 'calibration' in args.phase:
model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
- histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
+ histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
+ bias_calibration=args.bias_calibration, dummy_input=dummy_input,
lr_calib=args.lr_calib)
elif 'validation' in args.phase:
# Note: bias_calibration is not used in test
count_flops(args, model)
#################################################
- if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)):
+ if args.save_onnx:
write_onnx_model(args, get_model_orig(model), save_path)
#
close(args)
return
+ grad_scaler = torch.cuda.amp.GradScaler() if args.model_config.enable_fp16 else None
+
for epoch in range(args.start_epoch, args.stop_epoch):
if args.distributed:
train_loader.sampler.set_epoch(epoch)
# train for one epoch
- train(args, train_loader, model, criterion, optimizer, epoch)
+ train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler)
# evaluate on validation set
prec1 = validate(args, val_loader, model, criterion, epoch)
def close(args):
# Close the logger stream, if one is attached, and reset args.logger to None.
if args.logger is not None:
+ args.logger.close()
del args.logger
args.logger = None
#
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
-def train(args, train_loader, model, criterion, optimizer, epoch):
+def train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler):
# Train the model for one epoch; the caller invokes this once per epoch.
# grad_scaler: a torch.cuda.amp.GradScaler when args.model_config.enable_fp16 is set,
# otherwise None; it is only touched on the fp16 path.
# actual training code
batch_time = AverageMeter()
data_time = AverageMeter()
# switch to train mode
model.train()
- if args.freeze_bn:
+
+ # freeze bn and range after some epochs during quantization
# BN is frozen when args.freeze_bn is set, or, when quantizing, from epoch max(3, epochs//2 - 1) onwards.
+ if args.freeze_bn or (args.quantize and epoch > 2 and epoch >= ((args.epochs//2)-1)):
+ xnn.utils.print_once('Freezing BN for subsequent epochs')
xnn.utils.freeze_bn(model)
#
# Quantization ranges are frozen later than BN: from epoch max(5, epochs//2 + 1) onwards.
+ if (args.quantize and epoch > 4 and epoch >= ((args.epochs//2)+1)):
+ xnn.utils.print_once('Freezing ranges for subsequent epochs')
+ xnn.layers.freeze_quant_range(model)
+ #
num_iters = len(train_loader)
progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
top5.update(prec5[0], input_size[0])
if 'training' in args.phase:
- # zero gradients so that we can accumulate gradients
- if (iteration % args.iter_size) == 0:
- optimizer.zero_grad()
-
- loss.backward()
+ if args.model_config.enable_fp16:
+ grad_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+ #
if ((iteration+1) % args.iter_size) == 0:
- optimizer.step()
+ if args.model_config.enable_fp16:
+ grad_scaler.step(optimizer)
+ grad_scaler.update()
+ else:
+ optimizer.step()
+ #
+ # setting grad=None is a faster alternative instead of optimizer.zero_grad()
+ xnn.utils.clear_grad(model)
+ #
#
# measure elapsed time
def adjust_learning_rate(args, optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
- cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr
+ cur_lr = args.lr
- if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
- cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
+ if (args.warmup_epochs is not None) and (epoch <= args.warmup_epochs):
+ cur_lr = epoch * args.lr / args.warmup_epochs
+ if epoch == 0 and args.warmup_factor is not None:
+ cur_lr = max(cur_lr, args.lr * args.warmup_factor)
+ #
elif args.scheduler == 'poly':
epoch_frac = (args.epochs - epoch) / args.epochs
epoch_frac = max(epoch_frac, 0)
cur_lr = args.lr * (epoch_frac ** args.polystep_power)
- for param_group in optimizer.param_groups:
- param_group['lr'] = cur_lr
- #
elif args.scheduler == 'step': # step
num_milestones = 0
for m in args.milestones:
cur_lr = args.lr
else:
lr_min = 0
- cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0 + lr_min
+ cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0 + lr_min
#
else:
ValueError('Unknown scheduler {}'.format(args.scheduler))
res = []
for k in topk:
- correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+ correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
always_use_val_transform = (args.rand_scale[0] == 0)
train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
val_transform = get_validation_transform(args)
- return train_transform, val_transform
+ return (train_transform, val_transform)
def get_data_loaders(args):
- train_transform, val_transform = get_transforms(args)
+ train_transform, val_transform = get_transforms(args) if args.transforms is None else (args.transforms[0], args.transforms[1])
train_dataset, val_dataset = xvision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))