[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_classification.py
1 import os
2 import shutil
3 import time
5 import random
6 import numpy as np
7 from colorama import Fore
8 import math
9 import progiter
10 import warnings
12 import torch
13 import torch.nn.parallel
14 import torch.backends.cudnn as cudnn
15 import torch.distributed as dist
16 import torch.optim
17 import torch.utils.data
18 import torch.utils.data.distributed
20 import sys
21 import datetime
23 import onnx
24 from onnx import shape_inference
26 from .. import xnn
27 from .. import vision
30 #################################################
def get_config():
    """Return the default configuration (an xnn.utils.ConfigNode) for classification training.

    Every field can be overridden by the caller before passing the config to main().
    """
    args = xnn.utils.ConfigNode()
    args.model_config = xnn.utils.ConfigNode()
    args.dataset_config = xnn.utils.ConfigNode()
    args.model_config.num_tiles_x = int(1)
    args.model_config.num_tiles_y = int(1)
    args.model_config.en_make_divisible_by8 = True

    args.model_config.input_channels = 3                # num input channels

    args.data_path = './data/datasets/ilsvrc'           # path to dataset
    args.model_name = 'mobilenetv2_tv_x1'               # model architecture
    args.model = None                                   # if model is created externally, pass it here
    args.dataset_name = 'imagenet_classification'       # image folder classification
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.workers = 8                                    # number of data loading workers (default: 8)
    args.logger = None                                  # logger stream to output into

    args.epochs = 90                                    # number of total epochs to run
    args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr

    args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.start_epoch = 0                                # manual epoch number to start
    args.stop_epoch = None                              # manual epoch number to stop
    args.batch_size = 256                               # mini_batch size (default: 256)
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 0.1                                       # initial learning rate
    args.lr_clips = None                                # use args.lr itself if it is None
    args.lr_calib = 0.05                                # lr for bias calibration
    args.momentum = 0.9                                 # momentum
    args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
    args.bias_decay = None                              # bias decay (default: 0.0)

    args.shuffle = True                                 # shuffle or not
    args.shuffle_val = True                             # shuffle val dataset or not

    args.rand_seed = 1                                  # random seed
    args.print_freq = 100                               # print frequency (default: 100)
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.world_size = 1                                 # number of distributed processes
    args.dist_url = 'tcp://224.66.41.62:23456'          # url used to set up distributed training
    args.dist_backend = 'gloo'                          # distributed backend

    args.optimizer = 'sgd'                              # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly','exponential','cosine']
    args.milestones = (30, 60, 90)                      # epochs at which learning rate is divided
    args.multistep_gamma = 0.1                          # multi step gamma (default: 0.1)
    args.polystep_power = 1.0                           # poly step gamma (default: 1.0)
    # bug fix: this was "args.step_size = 1," which made step_size a tuple and broke
    # the integer division (epoch//args.step_size) used by the exponential scheduler
    args.step_size = 1                                  # step size for exp lr decay

    args.beta = 0.999                                   # beta parameter for adam
    args.pretrained = None                              # path to pre_trained model
    args.img_resize = 256                               # image resize
    args.img_crop = 224                                 # image crop
    args.rand_scale = (0.08,1.0)                        # random scale range for training
    args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
    args.count_flops = True                             # count flops and report

    args.save_onnx = True                               # export the model to onnx format or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # set to False if only cfs files/onnx models needed but no training

    args.multi_color_modes = None                       # input modes created with multi color transform
    args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization
    args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input image normalization

    args.parallel_model = True                          # use data parallel for model

    args.quantize = False                               # apply quantized inference or not
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not

    args.freeze_bn = False                              # freeze the statistics of bn
    args.save_mod_files = False                         # saves modified files after last commit. Also stores commit id.

    args.opset_version = 9                              # onnx opset_version
    return args
#################################################
# module-level default: enable the cudnn autotuner (fastest conv algorithms for
# fixed input sizes). main() sets this again together with cudnn.deterministic.
cudnn.benchmark = True
#cudnn.enabled = False
def main(args):
    """Run one full training/calibration/validation session as configured in *args*.

    *args* is the ConfigNode returned by get_config(), possibly modified by the caller.
    Side effects: creates the save directory, installs a TeeLogger on args.logger,
    seeds RNGs, and (in training) writes checkpoints/onnx files under the save path.
    """
    # reject deprecated configuration fields early
    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    # bias calibration only makes sense during calibration, not plain validation
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #

    args.best_prec1 = -1

    # resume has higher priority than pretrained weights
    args.pretrained = None if (args.resume is not None) else args.pretrained

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit, for reproducibility
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store the last commit id
        cmd = "git log -n 1 >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    #################################################
    # tee stdout into a log file inside the save path
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    ################################
    # the string 'None' may come from a command line parser - normalize it
    args.pretrained = None if (args.pretrained == 'None') else args.pretrained
    args.num_inputs = len(args.multi_color_modes) if (args.multi_color_modes is not None) else 1

    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    # keep iter_size and total_batch_size consistent: total = batch * iter
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size

    args.stop_epoch = args.stop_epoch if args.stop_epoch else args.epochs

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    # torch.autograd.set_detect_anomaly(True)

    ################################
    # print everything for log
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    print("=> args: ", args)
    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
    print("=> resize resolution: {}".format(args.img_resize))
    print("=> crop resolution : {}".format(args.img_crop))
    sys.stdout.flush()

    #################################################
    # load pretrained weights file (local path or downloaded url) if given
    pretrained_data = None
    model_surgery_quantize = False
    if args.pretrained and args.pretrained != "None":
        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
        else:
            pretrained_file = args.pretrained
        #
        print(f'=> using pre-trained weights from: {args.pretrained}')
        pretrained_data = torch.load(pretrained_file)
        # checkpoints saved by this script record whether they were quantized
        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
    #

    #################################################
    # create model
    print("=> creating model '{}'".format(args.model_name))

    model = vision.models.classification.__dict__[args.model_name](args.model_config) if args.model == None else args.model

    # check if we got the model as well as parameters to change the names in pretrained
    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)

    #################################################
    # wrap the model in the quantization module matching the phase
    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
                        bitwidth_activations=args.bitwidth_activations,
                        dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
                        lr_calib=args.lr_calib)
        elif 'validation' in args.phase:
            # Note: bias_calibration is not used in test
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, dummy_input=dummy_input,
                        model_surgery_quantize=model_surgery_quantize)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained weights into the unwrapped model
    if pretrained_data is not None:
        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    # export onnx up-front for training/calibration, or when no run will follow
    if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)):
        write_onnx_model(args, get_model_orig(model), save_path)
    #

    #################################################
    if args.print_model:
        print(model)
    else:
        args.logger.debug(str(model))

    #################################################
    # exit early if only the exported artifacts were wanted
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # multi gpu mode is not working for quantized model
    if args.parallel_model and (not args.quantize):
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.DataParallel(model)
    #

    #################################################
    model = model.cuda()

    #################################################
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # 'clips' parameters (quantization clip values) may get their own lr/decay
    model_module = model.module if hasattr(model, 'module') else model
    if args.lr_clips is not None:
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model_module.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model_module.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    print("=> args: ", args)
    print("=> optimizer type   : {}".format(args.optimizer))
    print("=> learning rate    : {}".format(args.lr))
    print("=> resize resolution: {}".format(args.img_resize))
    print("=> crop resolution  : {}".format(args.img_crop))
    print("=> batch size       : {}".format(args.batch_size))
    print("=> total batch size : {}".format(args.total_batch_size))
    print("=> epoch size       : {}".format(args.epoch_size))
    print("=> data augument    : {}".format(args.data_augument))
    print("=> epochs           : {}".format(args.epochs))
    if args.scheduler == 'step':
        print("=> milestones       : {}".format(args.milestones))

    # lr is zeroed outside training (calibration/validation must not change weights via lr)
    learning_rate = args.lr if ('training'in args.phase) else 0.0
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    elif args.optimizer == 'sgd_nesterov':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum, nesterov=True)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('Unknown optimizer type{}'.format(args.optimizer))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> resuming from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.start_epoch == 0:
                args.start_epoch = checkpoint['epoch'] + 1

            args.best_prec1 = checkpoint['best_prec1']
            model = xnn.utils.load_weights(model, checkpoint)
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> resuming from checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    train_loader, val_loader = get_data_loaders(args)

    args.cur_lr = adjust_learning_rate(args, optimizer, args.start_epoch)

    # optional evaluation before the first training epoch
    if args.evaluate_start or args.phase=='validation':
        validate(args, val_loader, model, criterion, args.start_epoch)

    if args.phase == 'validation':
        close(args)
        return

    # main epoch loop: train, validate, checkpoint (tracking best prec@1)
    for epoch in range(args.start_epoch, args.stop_epoch):
        if args.distributed:
            # ensure each process sees a different shard/order per epoch
            train_loader.sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > args.best_prec1
        args.best_prec1 = max(prec1, args.best_prec1)

        # checkpoint the unwrapped model so it can be loaded without the wrappers
        model_orig = get_model_orig(model)

        save_dict = {'epoch': epoch, 'arch': args.model_name, 'state_dict': model_orig.state_dict(),
                     'best_prec1': args.best_prec1, 'optimizer' : optimizer.state_dict(),
                     'quantize' : args.quantize}

        save_checkpoint(args, save_path, model_orig, save_dict, is_best)
    #

    # for n, m in model.named_modules():
    #     if hasattr(m, 'num_batches_tracked'):
    #         print(f'name={n}, num_batches_tracked={m.num_batches_tracked}')
    #     #

    # close and cleanup
    close(args)
#
395 ###################################################################
def is_valid_phase(phase):
    """Return True if *phase* contains one of the supported run modes."""
    for known in ('training', 'calibration', 'validation'):
        if known in phase:
            return True
    return False
def close(args):
    """Release the logger (if any) and reset the tracked best accuracy."""
    if args.logger is not None:
        # drop the reference first so a TeeLogger can flush/close via __del__
        del args.logger
        args.logger = None
    args.best_prec1 = -1
def get_save_path(args, phase=None):
    """Build the checkpoint directory path from date, dataset/model names, resolution and phase."""
    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = date + '_' + args.dataset_name + '_' + args.model_name
    base = os.path.join('./data/checkpoints', args.dataset_name, run_name)
    base = base + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
    if phase is None:
        phase = args.phase
    return os.path.join(base, phase)
def get_model_orig(model):
    """Unwrap (Distributed)DataParallel and quantization wrappers to reach the raw model."""
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model = model.module
    if isinstance(model, (xnn.quantize.QuantBaseModule)):
        model = model.module
    return model
def create_rand_inputs(args, is_cuda):
    """Create one random dummy input tensor of the configured (tiled) input size."""
    height = args.img_crop * args.model_config.num_tiles_y
    width = args.img_crop * args.model_config.num_tiles_x
    sample = torch.rand((1, args.model_config.input_channels, height, width))
    if is_cuda:
        sample = sample.cuda()
    return sample
def count_flops(args, model):
    """Measure and print the model's GFLOPs/GMACs via one dummy forward pass."""
    on_cuda = next(model.parameters()).is_cuda
    sample = create_rand_inputs(args, on_cuda)
    model.eval()
    total_flops = xnn.utils.forward_count_flops(model, sample)
    gflops = total_flops / 1e9
    print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
    """Export *model* to ONNX at save_path/name using a random dummy input."""
    on_cuda = next(model.parameters()).is_cuda
    sample = create_rand_inputs(args, on_cuda)
    model.eval()
    onnx_path = os.path.join(save_path, name)
    torch.onnx.export(model, sample, onnx_path, export_params=True, verbose=False,
                      do_constant_folding=True, opset_version=args.opset_version)
    # embed tensor shapes in the ONNX graph for viewing. Works only upto ver 8
    if args.opset_version <= 8:
        onnx.save(onnx.shape_inference.infer_shapes(onnx.load(onnx_path)), onnx_path)
def train(args, train_loader, model, criterion, optimizer, epoch):
    """Run one training (or calibration) epoch over *train_loader*.

    Supports gradient accumulation over args.iter_size mini-batches and the
    tiled-input mode (num_tiles_x/y > 1). Updates args.cur_lr via the scheduler.
    """
    # actual training code
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    if args.freeze_bn:
        xnn.utils.freeze_bn(model)
    #

    num_iters = len(train_loader)
    progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
    args.cur_lr = adjust_learning_rate(args, optimizer, epoch)

    end = time.time()
    last_update_iter = -1

    # yellow progress bar signals quantization-related phases
    progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
    print('{}'.format(progressbar_color), end='')

    for iteration, (input, target) in enumerate(train_loader):
        # input may be a list of tensors when multi_color_modes is used
        input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
        input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
        target = target.cuda(non_blocking=True)

        data_time.update(time.time() - end)

        # preprocess to make tiles
        if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
            input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
        #

        # compute output
        output = model(input)

        if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
            # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class]
            # e.g. [1,10,4,5] to [1,4,5,10]
            output = output.permute(0, 2, 3, 1)
            # change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
            output = torch.reshape(output, (-1, output.shape[-1]))
        #

        # compute loss; divide by iter_size so accumulated gradients match a big batch
        loss = criterion(output, target) / args.iter_size

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input_size[0])
        top1.update(prec1[0], input_size[0])
        top5.update(prec5[0], input_size[0])

        if 'training' in args.phase:
            # zero gradients so that we can accumulate gradients
            if (iteration % args.iter_size) == 0:
                optimizer.zero_grad()

            loss.backward()

            # step only after iter_size mini-batches have accumulated
            if ((iteration+1) % args.iter_size) == 0:
                optimizer.step()
        #

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        final_iter = (iteration >= (num_iters-1))

        if ((iteration % args.print_freq) == 0) or final_iter:
            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} DataTime={data_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
                .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)

            progress_bar.set_description(f'=> {args.phase} ')
            progress_bar.set_postfix(Epoch='{}'.format(status_str))
            progress_bar.update(iteration-last_update_iter)
            last_update_iter = iteration
        #
    #

    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    ##########################
    # log the learned/observed activation clip values of PAct2 layers for debugging
    if args.quantize:
        def debug_format(v):
            return ('{:.3f}'.format(v) if v is not None else 'None')
        #
        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
        if len(clips_act) > 0:
            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
            args.logger.debug('')
        #
def validate(args, val_loader, model, criterion, epoch):
    """Evaluate *model* on *val_loader* and return the average top-1 accuracy.

    Runs under torch.no_grad(); supports the same tiled-input mode as train().
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_iters = len(val_loader)
    progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
    last_update_iter = -1

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    with torch.no_grad():
        end = time.time()
        for iteration, (input, target) in enumerate(val_loader):
            # input may be a list of tensors when multi_color_modes is used
            input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
            input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
            target = target.cuda(non_blocking=True)

            # preprocess to make tiles
            if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
                input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
            #

            # compute output
            output = model(input)

            if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
                # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class]
                # e.g. [1,10,4,5] to [1,4,5,10]
                output = output.permute(0,2,3,1)
                # change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
                output = torch.reshape(output, (-1, output.shape[-1]))
            #

            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input_size[0])
            top1.update(prec1[0], input_size[0])
            top5.update(prec5[0], input_size[0])

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            final_iter = (iteration >= (num_iters-1))

            if ((iteration % args.print_freq) == 0) or final_iter:
                epoch_str = '{}/{}'.format(epoch+1,args.epochs)
                status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
                    .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, loss=losses, top1=top1, top5=top5)

                # '**' marks the final summary line of the epoch
                prefix = '**' if final_iter else '=>'
                progress_bar.set_description('{} {}'.format(prefix, 'validation'))
                progress_bar.set_postfix(Epoch='{}'.format(status_str))
                progress_bar.update(iteration - last_update_iter)
                last_update_iter = iteration
            #
        #

    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    return top1.avg
def save_checkpoint(args, save_path, model, state, is_best, filename='checkpoint.pth'):
    """Write the checkpoint dict to disk; copy/export best-model and onnx variants."""
    checkpoint_path = os.path.join(save_path, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        # keep a separate copy of the best checkpoint so far
        shutil.copyfile(checkpoint_path, os.path.join(save_path, 'model_best.pth'))
    if args.save_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum = self.sum + (val * n)
        self.count = self.count + n
        self.avg = self.sum / self.count
def adjust_learning_rate(args, optimizer, epoch):
    """Compute the learning rate for *epoch*, apply it to all param groups and return it.

    Supports an optional linear warmup (args.warmup_epochs) followed by one of the
    'step', 'poly', 'exponential' or 'cosine' schedules selected by args.scheduler.
    Raises ValueError for an unknown scheduler name.
    """
    cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr

    if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
        # linear warmup: ramp up from lr/warmup_epochs towards lr
        cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
    elif args.scheduler == 'poly':
        epoch_frac = (args.epochs - epoch) / args.epochs
        epoch_frac = max(epoch_frac, 0)
        cur_lr = args.lr * (epoch_frac ** args.polystep_power)
        # note: the common update of optimizer.param_groups below covers this branch;
        # the original duplicated it here redundantly
    elif args.scheduler == 'step':
        # divide lr by gamma once for every milestone already passed
        num_milestones = sum(1 for m in args.milestones if epoch >= m)
        cur_lr = args.lr * (args.multistep_gamma ** num_milestones)
    elif args.scheduler == 'exponential':
        cur_lr = args.lr * (args.multistep_gamma ** (epoch//args.step_size))
    elif args.scheduler == 'cosine':
        if epoch == 0:
            cur_lr = args.lr
        else:
            lr_min = 0
            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0 + lr_min
        #
    else:
        # bug fix: the original constructed this ValueError but never raised it,
        # silently keeping the previous lr for unknown scheduler names
        raise ValueError('Unknown scheduler {}'.format(args.scheduler))
    #
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    #
    return cur_lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    output: (batch, num_classes) scores; target: (batch,) class indices.
    Returns a list of 1-element tensors holding percentages, one per k in topk.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # bug fix: use reshape instead of view - correct[:k] can be
            # non-contiguous in newer torch versions, making view() raise
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def get_dataset_sampler(dataset_object, epoch_size, balanced_sampler=False):
    """Create a random (optionally class-balanced) sampler drawing *epoch_size* samples.

    An epoch_size below 1 is interpreted as a fraction of the dataset length.
    """
    num_samples = len(dataset_object)
    if epoch_size < 1:
        epoch_size = int(epoch_size * num_samples)
    else:
        epoch_size = int(epoch_size)
    print('=> creating a random sampler as epoch_size is specified')
    if not balanced_sampler:
        return torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
    # weigh each sample inversely to its class frequency
    # going through the dataset this way may take too much time
    progress_bar = progiter.ProgIter(np.arange(num_samples), chunksize=1, \
        desc='=> reading data to create a balanced data sampler : ')
    sample_classes = [target for _, target in progress_bar(dataset_object)]
    num_classes = max(sample_classes) + 1
    sample_counts = np.zeros(num_classes, dtype=np.int32)
    for target in sample_classes:
        sample_counts[target] += 1
    #
    train_class_weights = [float(num_samples) / float(cnt) for cnt in sample_counts]
    train_sample_weights = [train_class_weights[target] for target in sample_classes]
    return torch.utils.data.sampler.WeightedRandomSampler(train_sample_weights, epoch_size)
def get_train_transform(args):
    """Build the training transform pipeline: random resize/crop, flip, color modes, normalize."""
    if (args.image_mean is not None) and (args.image_scale is not None):
        normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale)
    else:
        normalize = None
    if args.multi_color_modes is not None:
        multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes)
    else:
        multi_color_transform = None
    if args.rand_scale:
        resize_crop = vision.transforms.RandomResizedCrop(size=args.img_crop, scale=args.rand_scale)
    else:
        resize_crop = vision.transforms.RandomCrop(size=args.img_crop)
    return vision.transforms.Compose([resize_crop,
                                      vision.transforms.RandomHorizontalFlip(),
                                      multi_color_transform,
                                      vision.transforms.ToFloat(),
                                      vision.transforms.ToTensor(),
                                      normalize])
def get_validation_transform(args):
    """Build the validation transform pipeline: resize, center crop, color modes, normalize."""
    if (args.image_mean is not None) and (args.image_scale is not None):
        normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale)
    else:
        normalize = None
    if args.multi_color_modes is not None:
        multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes)
    else:
        multi_color_transform = None
    # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
    if args.img_resize:
        resize = vision.transforms.Resize(size=args.img_resize)
    else:
        resize = vision.transforms.Bypass()
    return vision.transforms.Compose([resize,
                                      vision.transforms.CenterCrop(size=args.img_crop),
                                      multi_color_transform,
                                      vision.transforms.ToFloat(),
                                      vision.transforms.ToTensor(),
                                      normalize])
def get_transforms(args):
    """Return (train_transform, val_transform) for the configured augmentation.

    Providing rand_scale as (0, 0) trains with the validation transform
    (fixes the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423).
    """
    val_transform = get_validation_transform(args)
    if args.rand_scale[0] == 0:
        train_transform = get_validation_transform(args)
    else:
        train_transform = get_train_transform(args)
    return train_transform, val_transform
def get_data_loaders(args):
    """Construct the train and validation DataLoaders with the appropriate samplers."""
    train_transform, val_transform = get_transforms(args)

    train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](
        args.dataset_config, args.data_path, transforms=(train_transform, val_transform))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        # a non-zero epoch_size requests a random subset per epoch
        train_sampler = None
        val_sampler = None
        if args.epoch_size != 0:
            train_sampler = get_dataset_sampler(train_dataset, args.epoch_size)
        if args.epoch_size_val != 0:
            val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val)
    #

    # a DataLoader cannot shuffle when an explicit sampler is given
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(args.shuffle and (train_sampler is None)),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=(args.shuffle_val and (val_sampler is None)),
        num_workers=args.workers, pin_memory=True, drop_last=False, sampler=val_sampler)

    return train_loader, val_loader
if __name__ == '__main__':
    # bug fix: main() requires the config argument; calling main() bare raised a
    # TypeError. Run with the default configuration when invoked as a script.
    main(get_config())