51d619264753cf4994750a8ca769a0de245a4c08
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / scripts / train_classification_main.py
1 #!/usr/bin/env python
3 import sys
4 import os
5 import cv2
6 import argparse
7 import datetime
8 import numpy as np
10 ################################
11 from pytorch_jacinto_ai.xnn.utils import str2bool
12 parser = argparse.ArgumentParser()
13 parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
14 parser.add_argument('--gpus', type=int, nargs='*', default=None, help='Base learning rate')
15 parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
16 parser.add_argument('--strides', type=int, nargs='*', default=None, help='strides in the model')
17 parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
18 parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
19 parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
20 parser.add_argument('--model_name', type=str, default=None, help='model name')
21 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
22 parser.add_argument('--data_path', type=str, default=None, help='data path')
23 parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
24 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
25 parser.add_argument('--warmup_epochs', type=int, default=None, help='number of epochs for the learning rate to increase and reach base value')
26 parser.add_argument('--milestones', type=int, nargs='*', default=None, help='change lr at these milestones')
27 parser.add_argument('--img_resize', type=int, default=None, help='images will be first resized to this size during training and validation')
28 parser.add_argument('--rand_scale', type=float, nargs=2, default=None, help='during training (only) fraction of the image to crop (this will then be resized to img_crop)')
29 parser.add_argument('--img_crop', type=int, default=None, help='the cropped portion (validation), cropped pertion will be resized to this size (training)')
30 parser.add_argument('--pretrained', type=str, default=None, help='pretrained model')
31 parser.add_argument('--resume', type=str, default=None, help='resume an unfinished training from this model')
32 parser.add_argument('--phase', type=str, default=None, help='training/calibration/validation')
33 parser.add_argument('--evaluate_start', type=str2bool, default=None, help='Whether to run validation before the training')
34 parser.add_argument('--workers', type=int, default=None, help='number of workers for dataloading')
35 #
36 parser.add_argument('--quantize', type=str2bool, default=None, help='Quantize the model')
37 parser.add_argument('--histogram_range', type=str2bool, default=None, help='run only evaluation and no training')
38 parser.add_argument('--per_channel_q', type=str2bool, default=None, help='run only evaluation and no training')
39 parser.add_argument('--bias_calibration', type=str2bool, default=None, help='run only evaluation and no training')
40 parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth for weight quantization')
41 parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
42 #
43 parser.add_argument('--freeze_bn', type=str2bool, default=None, help='freeze the bn stats or not')
45 cmds = parser.parse_args()
47 ################################
48 # taken care first, since this has to be done before importing pytorch
49 if 'gpus' in vars(cmds):
50     value = getattr(cmds, 'gpus')
51     if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
52         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
53     #
54 #
56 # to avoid hangs in data loader with multi threads
57 # this was observed after using cv2 image processing functions
58 # https://github.com/pytorch/pytorch/issues/1355
59 cv2.setNumThreads(0)
61 ################################
62 #import of torch should be after CUDA_VISIBLE_DEVICES for it to take effect
63 import torch
64 from pytorch_jacinto_ai.engine import train_classification
66 #Create the parse and set default arguments
67 args = train_classification.get_config()
69 ################################
70 date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
72 ################################
73 #Set arguments
74 args.model_name = 'mobilenetv2_tv_x1'        # 'resnet50_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_ericsun_x1', 'mobilenetv2_shicai_x1'
76 args.dataset_name = 'image_folder_classification' # 'image_folder_classification', 'imagenet_classification', 'cifar10_classification', 'cifar100_classification'
78 #args.save_path = './data/checkpoints'
80 args.data_path = f'./data/datasets/{args.dataset_name}'
82 # args.pretrained = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
83                     #'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
84                     #'./data/modelzoo/pretrained/pytorch/imagenet_classification/torchvision/resnet50-19c8e357.pth'
85                     #'./data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar'
87 #args.resume = './data/checkpoints/imagenet_classification/2019-01-24-09:34:45_imagenet_classification_shufflemobilenetv2_resize256_crop224_(epochs29)/checkpoint.pth.tar'
89 #args.start_epoch = 100
91 args.model_config.input_channels = 3
92 args.model_config.output_type = 'classification'
93 args.model_config.output_channels = None
94 args.model_config.strides = None #(2,2,2,2,2)
96 args.img_resize = 256
97 args.img_crop = 224
98 args.solver = 'sgd'                     #'sgd' #'adam'
99 args.epochs = 150                       #150 #120
100 args.start_epoch = 0                    #0
101 args.epoch_size = 0                     #0 #0.5
102 args.scheduler = 'cosine'               #'poly' #'step' #'exponential' #'cosine'
103 args.multistep_gamma = 0.1              #0.1 #0.94 #0.98 #for step and exponential schedulers
104 args.milestones = [30, 60, 90]          #only for step scheduler
105 args.polystep_power = 1.0               #only for poly scheduler
106 args.step_size = 1                      #only for exp scheduler
107 args.iter_size = 1                      #2
108 args.warmup_epochs = 5                  #5
110 args.lr = 0.1                           #0.2 #0.1 #0.045 #0.01
111 args.batch_size = 512 #256 #128 #512 #1024
112 args.weight_decay = 4e-5                #4e-5 #1e-4
113 args.rand_scale = (0.2, 1.0)            #(0.08, 1.0)
115 #args.print_model = True
116 #args.generate_onnx = False
117 #args.run_soon = True #default(True) #Set to false if only cfs files/onnx  models needed but no training is required
118 #args.phase = 'validation'
119 # args.evaluate_start = False
121 # args.quantize = True
122 # args.phase = 'training'
124 # defining date from outside can help to write multiple pahses into the same folder
125 args.date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
127 ################################
128 for key in vars(cmds):
129     if key == 'gpus': # already taken care above, since this has to be done before importing pytorch
130         pass
131     elif key == 'strides': # strides is in model_config
132         value = getattr(cmds, key)
133         if value != 'None' and value is not None:
134             setattr(args.model_config, key, value)
135         #
136     elif hasattr(args, key):
137         value = getattr(cmds, key)
138         if value != 'None' and value is not None:
139             setattr(args, key, value)
140     else:
141         assert False, f'invalid argument {key}'
144 ################################
145 # these dependent on the dataset chosen
146 args.model_config.num_classes = (100 if 'cifar100' in args.dataset_name else (10  if 'cifar10' in args.dataset_name else 1000))
150 ################################
151 # Run the training
152 train_classification.main(args)
154 ################################
155 # In addition run a quantized calibration, starting from the trained model
156 if 'training' in args.phase and (not args.quantize):
157     save_path = train_classification.get_save_path(args)
158     args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
159     args.phase = 'training_quantize'
160     args.quantize = True
161     args.lr = 1e-5
162     args.epochs = 25
163     # quantized training will use only one GPU in the engine - so reduce the batch_size
164     num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(','))
165     args.batch_size = args.batch_size//num_gpus
166     train_classification.main(args)
169 ################################
170 # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately
171 if 'training' in args.phase or 'calibration' in args.phase:
172     save_path = train_classification.get_save_path(args)
173     args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
174     args.phase = 'validation'
175     args.quantize = True
176     train_classification.main(args)