]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/engine/train_classification.py
added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_classification.py
index bc208a4f6e4d0d88b69ed9cd589b898713d2d7bf..9bfde8e6b7bcfc440167b5d69862096917671d0d 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import shutil
 import time
@@ -32,31 +60,37 @@ def get_config():
     args = xnn.utils.ConfigNode()
     args.model_config = xnn.utils.ConfigNode()
     args.dataset_config = xnn.utils.ConfigNode()
+    args.model_config.input_channels = 3                # num input channels
+    args.model_config.output_type = 'classification'
+    args.model_config.output_channels = None
+    args.model_config.strides = None                    # (2,2,2,2,2)
     args.model_config.num_tiles_x = int(1)
     args.model_config.num_tiles_y = int(1)
     args.model_config.en_make_divisible_by8 = True
-    args.model_config.input_channels = 3                # num input channels
+    args.model_config.enable_fp16 = False               # FP16 half precision mode
 
     args.input_channel_reverse = False                  # rgb to bgr
     args.data_path = './data/datasets/ilsvrc'           # path to dataset
     args.model_name = 'mobilenetv2_tv_x1'     # model architecture'
     args.model = None                                   #if mdoel is crated externaly 
     args.dataset_name = 'imagenet_classification'       # image folder classification
+    args.transforms = None                              # the transforms itself can be given from outside
     args.save_path = None                               # checkpoints save path
     args.phase = 'training'                             # training/calibration/validation
     args.date = None                                    # date to add to save path. if this is None, current date will be added.
 
-    args.workers =                                    # number of data loading workers (default: 8)
+    args.workers = 12                                   # number of data loading workers (default: 8)
     args.logger = None                                  # logger stream to output into
 
-    args.epochs = 90                                    # number of total epochs to run
-    args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr
+    args.epochs = 150                                   # number of total epochs to run: recommended 100 or 150
+    args.warmup_epochs = 5                              # number of epochs to warm up by linearly increasing lr
+    args.warmup_factor = 1e-3                           # max lr allowed for the first epoch during warmup (as a factor of initial lr)
 
     args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
     args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
     args.start_epoch = 0                                # manual epoch number to start
     args.stop_epoch = None                              # manual epoch number to stop
-    args.batch_size = 256                               # mini_batch size (default: 256)
+    args.batch_size = 512                               # mini_batch size (default: 256)
     args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
     args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
 
@@ -64,7 +98,7 @@ def get_config():
     args.lr_clips = None                                # use args.lr itself if it is None
     args.lr_calib = 0.05                                # lr for bias calibration
     args.momentum = 0.9                                 # momentum
-    args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
+    args.weight_decay = 4e-5                            # weight decay (default: 1e-4)
     args.bias_decay = None                              # bias decay (default: 0.0)
 
     args.shuffle = True                                 # shuffle or not
@@ -78,18 +112,18 @@ def get_config():
     args.dist_url = 'tcp://224.66.41.62:23456'          # url used to set up distributed training
     args.dist_backend = 'gloo'                          # distributed backend
 
-    args.optimizer = 'sgd'                              # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
-    args.scheduler = 'step'                             # help='scheduler algorithms, choices=['step','poly','exponential', 'cosine']
+    args.optimizer = 'sgd'                              # optimizer algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
+    args.scheduler = 'cosine'                           # help='scheduler algorithms, choices=['step','poly','exponential', 'cosine']
     args.milestones = (30, 60, 90)                      # epochs at which learning rate is divided
     args.multistep_gamma = 0.1                          # multi step gamma (default: 0.1)
     args.polystep_power = 1.0                           # poly step gamma (default: 1.0)
-    args.step_size = 1,                                 # step size for exp lr decay
+    args.step_size = 1                                  # step size for exp lr decay
 
     args.beta = 0.999                                   # beta parameter for adam
     args.pretrained = None                              # path to pre_trained model
     args.img_resize = 256                               # image resize
     args.img_crop = 224                                 # image crop
-    args.rand_scale = (0.08,1.0)                        # random scale range for training
+    args.rand_scale = (0.2,1.0)                         # random scale range for training
     args.data_augument = 'inception'                    # data augumentation method, choices=['inception','resize','adaptive_resize']
     args.count_flops = True                             # count flops and report
 
@@ -110,11 +144,12 @@ def get_config():
     args.histogram_range = True                         # histogram range for calibration
     args.bias_calibration = True                        # apply bias correction during quantized inference calibration
     args.per_channel_q = False                          # apply separate quantizion factor for each channel in depthwise or not
+    args.constrain_bias = None                          # constrain bias according to the constraints of convolution engine
 
     args.freeze_bn = False                              # freeze the statistics of bn
     args.save_mod_files = False                         # saves modified files after last commit. Also  stores commit id.
 
-    args.opset_version =                              # onnx opset_version
+    args.opset_version = 11                             # onnx opset_version
     return args
 
 
@@ -246,12 +281,13 @@ def main(args):
         if 'training' in args.phase:
             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                         histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
-                        bitwidth_activations=args.bitwidth_activations,
+                        bitwidth_activations=args.bitwidth_activations, constrain_bias=args.constrain_bias,
                         dummy_input=dummy_input)
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
+                        histogram_range=args.histogram_range,  constrain_bias=args.constrain_bias,
+                        bias_calibration=args.bias_calibration, dummy_input=dummy_input,
                         lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not used in test
@@ -278,7 +314,7 @@ def main(args):
         count_flops(args, model)
 
     #################################################
-    if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)):
+    if args.save_onnx:
         write_onnx_model(args, get_model_orig(model), save_path)
     #
 
@@ -374,12 +410,14 @@ def main(args):
         close(args)
         return
 
+    grad_scaler = torch.cuda.amp.GradScaler() if args.model_config.enable_fp16 else None
+
     for epoch in range(args.start_epoch, args.stop_epoch):
         if args.distributed:
             train_loader.sampler.set_epoch(epoch)
 
         # train for one epoch
-        train(args, train_loader, model, criterion, optimizer, epoch)
+        train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler)
 
         # evaluate on validation set
         prec1 = validate(args, val_loader, model, criterion, epoch)
@@ -414,6 +452,7 @@ def is_valid_phase(phase):
 
 def close(args):
     if args.logger is not None:
+        args.logger.close()
         del args.logger
         args.logger = None
     #
@@ -467,7 +506,7 @@ def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
         onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
 
 
-def train(args, train_loader, model, criterion, optimizer, epoch):
+def train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler):
     # actual training code
     batch_time = AverageMeter()
     data_time = AverageMeter()
@@ -477,9 +516,16 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
 
     # switch to train mode
     model.train()
-    if args.freeze_bn:
+
+    # freeze bn and range after some epochs during quantization
+    if args.freeze_bn or (args.quantize and epoch > 2 and epoch >= ((args.epochs//2)-1)):
+        xnn.utils.print_once('Freezing BN for subsequent epochs')
         xnn.utils.freeze_bn(model)
     #
+    if (args.quantize and epoch > 4 and epoch >= ((args.epochs//2)+1)):
+        xnn.utils.print_once('Freezing ranges for subsequent epochs')
+        xnn.layers.freeze_quant_range(model)
+    #
 
     num_iters = len(train_loader)
     progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
@@ -525,14 +571,22 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
         top5.update(prec5[0], input_size[0])
 
         if 'training' in args.phase:
-            # zero gradients so that we can accumulate gradients
-            if (iteration % args.iter_size) == 0:
-                optimizer.zero_grad()
-
-            loss.backward()
+            if args.model_config.enable_fp16:
+                grad_scaler.scale(loss).backward()
+            else:
+                loss.backward()
+            #
 
             if ((iteration+1) % args.iter_size) == 0:
-                optimizer.step()
+                if args.model_config.enable_fp16:
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                else:
+                    optimizer.step()
+                #
+                # setting grad=None is a faster alternative instead of optimizer.zero_grad()
+                xnn.utils.clear_grad(model)
+            #
         #
 
         # measure elapsed time
@@ -677,17 +731,17 @@ class AverageMeter(object):
 
 def adjust_learning_rate(args, optimizer, epoch):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr
+    cur_lr = args.lr
 
-    if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
-        cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
+    if (args.warmup_epochs is not None) and (epoch <= args.warmup_epochs):
+        cur_lr = epoch * args.lr / args.warmup_epochs
+        if epoch == 0 and args.warmup_factor is not None:
+            cur_lr = max(cur_lr, args.lr * args.warmup_factor)
+        #
     elif args.scheduler == 'poly':
         epoch_frac = (args.epochs - epoch) / args.epochs
         epoch_frac = max(epoch_frac, 0)
         cur_lr = args.lr * (epoch_frac ** args.polystep_power)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = cur_lr
-        #
     elif args.scheduler == 'step':                                            # step
         num_milestones = 0
         for m in args.milestones:
@@ -701,7 +755,7 @@ def adjust_learning_rate(args, optimizer, epoch):
             cur_lr = args.lr
         else:
             lr_min = 0
-            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0  + lr_min
+            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0 + lr_min
         #
     else:
         ValueError('Unknown scheduler {}'.format(args.scheduler))
@@ -724,7 +778,7 @@ def accuracy(output, target, topk=(1,)):
 
         res = []
         for k in topk:
-            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
             res.append(correct_k.mul_(100.0 / batch_size))
         return res
 
@@ -792,10 +846,10 @@ def get_transforms(args):
     always_use_val_transform = (args.rand_scale[0] == 0)
     train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
     val_transform = get_validation_transform(args)
-    return train_transform, val_transform
+    return (train_transform, val_transform)
 
 def get_data_loaders(args):
-    train_transform, val_transform = get_transforms(args)
+    train_transform, val_transform = get_transforms(args) if args.transforms is None else (args.transforms[0], args.transforms[1])
 
     train_dataset, val_dataset = xvision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))