]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/engine/train_classification.py
minor update to quantization docs & scritps. support for external model in classifica...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_classification.py
1 import os
2 import shutil
3 import time
5 import random
6 import numpy as np
7 from colorama import Fore
8 import math
9 import progiter
10 import warnings
12 import torch
13 import torch.nn.parallel
14 import torch.backends.cudnn as cudnn
15 import torch.distributed as dist
16 import torch.optim
17 import torch.utils.data
18 import torch.utils.data.distributed
20 import sys
21 import datetime
23 import onnx
24 from onnx import shape_inference
26 from .. import xnn
27 from .. import vision
30 #################################################
31 def get_config():
32     args = xnn.utils.ConfigNode()
33     args.model_config = xnn.utils.ConfigNode()
34     args.dataset_config = xnn.utils.ConfigNode()
35     args.model_config.num_tiles_x = int(1)
36     args.model_config.num_tiles_y = int(1)
37     args.model_config.en_make_divisible_by8 = True
39     args.model_config.input_channels = 3                # num input channels
41     args.data_path = './data/datasets/ilsvrc'           # path to dataset
42     args.model_name = 'mobilenetv2_tv_x1'     # model architecture'
43     args.model = None                                   #if mdoel is crated externaly 
44     args.dataset_name = 'imagenet_classification'       # image folder classification
45     args.save_path = None                               # checkpoints save path
46     args.phase = 'training'                             # training/calibration/validation
47     args.date = None                                    # date to add to save path. if this is None, current date will be added.
49     args.workers = 8                                    # number of data loading workers (default: 8)
50     args.logger = None                                  # logger stream to output into
52     args.epochs = 90                                    # number of total epochs to run
53     args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr
55     args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
56     args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
57     args.start_epoch = 0                                # manual epoch number to start
58     args.stop_epoch = None                              # manual epoch number to stop
59     args.batch_size = 256                               # mini_batch size (default: 256)
60     args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
61     args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
63     args.lr = 0.1                                       # initial learning rate
64     args.lr_clips = None                                # use args.lr itself if it is None
65     args.lr_calib = 0.05                                # lr for bias calibration
66     args.momentum = 0.9                                 # momentum
67     args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
68     args.bias_decay = None                              # bias decay (default: 0.0)
70     args.shuffle = True                                 # shuffle or not
71     args.shuffle_val = True                             # shuffle val dataset or not
73     args.rand_seed = 1                                  # random seed
74     args.print_freq = 100                               # print frequency (default: 100)
75     args.resume = None                                  # path to latest checkpoint (default: none)
76     args.evaluate_start = True                          # evaluate right at the begining of training or not
77     args.world_size = 1                                 # number of distributed processes
78     args.dist_url = 'tcp://224.66.41.62:23456'          # url used to set up distributed training
79     args.dist_backend = 'gloo'                          # distributed backend
81     args.optimizer = 'sgd'                              # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
82     args.scheduler = 'step'                             # help='scheduler algorithms, choices=['step','poly','exponential', 'cosine']
83     args.milestones = (30, 60, 90)                      # epochs at which learning rate is divided
84     args.multistep_gamma = 0.1                          # multi step gamma (default: 0.1)
85     args.polystep_power = 1.0                           # poly step gamma (default: 1.0)
86     args.step_size = 1,                                 # step size for exp lr decay
88     args.beta = 0.999                                   # beta parameter for adam
89     args.pretrained = None                              # path to pre_trained model
90     args.img_resize = 256                               # image resize
91     args.img_crop = 224                                 # image crop
92     args.rand_scale = (0.08,1.0)                        # random scale range for training
93     args.data_augument = 'inception'                    # data augumentation method, choices=['inception','resize','adaptive_resize']
94     args.count_flops = True                             # count flops and report
96     args.save_onnx = True                           # apply quantized inference or not
97     args.print_model = False                            # print the model to text
98     args.run_soon = True                                # Set to false if only cfs files/onnx  modelsneeded but no training
100     args.multi_color_modes = None                       # input modes created with multi color transform
101     args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization')
102     args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input iamge normalization')
104     args.parallel_model = True                          # Usedata parallel for model
106     args.quantize = False                               # apply quantized inference or not
107     #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
108     args.bitwidth_weights = 8                           # bitwidth for weights
109     args.bitwidth_activations = 8                       # bitwidth for activations
110     args.histogram_range = True                         # histogram range for calibration
111     args.bias_calibration = True                        # apply bias correction during quantized inference calibration
112     args.per_channel_q = False                          # apply separate quantizion factor for each channel in depthwise or not
114     args.freeze_bn = False                              # freeze the statistics of bn
115     args.save_mod_files = False                         # saves modified files after last commit. Also  stores commit id.
117     args.opset_version = 9                              # onnx opset_version
118     return args
121 #################################################
122 cudnn.benchmark = True
123 #cudnn.enabled = False
127 def main(args):
128     assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
129     assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
130     assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
132     if (args.phase == 'validation' and args.bias_calibration):
133         args.bias_calibration = False
134         warnings.warn('switching off bias calibration in validation')
135     #
137     #################################################
138     if args.save_path is None:
139         save_path = get_save_path(args)
140     else:
141         save_path = args.save_path
142     #
144     args.best_prec1 = -1
146     # resume has higher priority
147     args.pretrained = None if (args.resume is not None) else args.pretrained
149     if not os.path.exists(save_path):
150         os.makedirs(save_path)
152     if args.save_mod_files:
153         #store all the files after the last commit.
154         mod_files_path = save_path+'/mod_files'
155         os.makedirs(mod_files_path)
156         
157         cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
158         print("cmd:", cmd)    
159         os.system(cmd)
161         #stoe last commit id. 
162         cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
163         print("cmd:", cmd)    
164         os.system(cmd)
165     #################################################
166     if args.logger is None:
167         log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
168         args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
170     ################################
171     args.pretrained = None if (args.pretrained == 'None') else args.pretrained
172     args.num_inputs = len(args.multi_color_modes) if (args.multi_color_modes is not None) else 1
174     if args.iter_size != 1 and args.total_batch_size is not None:
175         warnings.warn("only one of --iter_size or --total_batch_size must be set")
176     #
177     if args.total_batch_size is not None:
178         args.iter_size = args.total_batch_size//args.batch_size
179     else:
180         args.total_batch_size = args.batch_size*args.iter_size
182     args.stop_epoch = args.stop_epoch if args.stop_epoch else args.epochs
184     args.distributed = args.world_size > 1
186     if args.distributed:
187         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
188                                 world_size=args.world_size)
190     #################################################
191     # global settings. rand seeds for repeatability
192     random.seed(args.rand_seed)
193     np.random.seed(args.rand_seed)
194     torch.manual_seed(args.rand_seed)
195     torch.backends.cudnn.deterministic = True
196     torch.backends.cudnn.benchmark = True
197     # torch.autograd.set_detect_anomaly(True)
199     ################################
200     # print everything for log
201     # reset character color, in case it is different
202     print('{}'.format(Fore.RESET))
203     print("=> args: ", args)
204     print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
205     print("=> resize resolution: {}".format(args.img_resize))
206     print("=> crop resolution  : {}".format(args.img_crop))
207     sys.stdout.flush()
209     #################################################
210     pretrained_data = None
211     model_surgery_quantize = False
212     if args.pretrained and args.pretrained != "None":
213         if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
214             pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
215         else:
216             pretrained_file = args.pretrained
217         #
218         print(f'=> using pre-trained weights from: {args.pretrained}')
219         pretrained_data = torch.load(pretrained_file)
220         model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
221     #
223     #################################################
224     # create model
225     print("=> creating model '{}'".format(args.model_name))
226     
227     model = vision.models.classification.__dict__[args.model_name](args.model_config) if args.model == None else args.model
229     # check if we got the model as well as parameters to change the names in pretrained
230     model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
232     #################################################
233     if args.quantize:
234         # dummy input is used by quantized models to analyze graph
235         is_cuda = next(model.parameters()).is_cuda
236         dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
237         #
238         if 'training' in args.phase:
239             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
240                         histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
241                         bitwidth_activations=args.bitwidth_activations,
242                         dummy_input=dummy_input)
243         elif 'calibration' in args.phase:
244             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
245                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
246                         histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
247                         lr_calib=args.lr_calib)
248         elif 'validation' in args.phase:
249             # Note: bias_calibration is not used in test
250             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
251                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
252                         histogram_range=args.histogram_range, dummy_input=dummy_input,
253                         model_surgery_quantize=model_surgery_quantize)
254         else:
255             assert False, f'invalid phase {args.phase}'
256     #
258     # load pretrained
259     if pretrained_data is not None:
260         xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
261     #
262     
263     #################################################
264     if args.count_flops:
265         count_flops(args, model)
267     #################################################
268     if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)):
269         write_onnx_model(args, get_model_orig(model), save_path)
270     #
272     #################################################
273     if args.print_model:
274         print(model)
275     else:
276         args.logger.debug(str(model))
278     #################################################
279     if (not args.run_soon):
280         print("Training not needed for now")
281         close(args)
282         exit()
284     #################################################
285     # multi gpu mode is not working for quantized model
286     if args.parallel_model and (not args.quantize):
287         if args.distributed:
288             model = torch.nn.parallel.DistributedDataParallel(model)
289         else:
290             model = torch.nn.DataParallel(model)
291     #
293     #################################################
294     model = model.cuda()
296     #################################################
297     # define loss function (criterion) and optimizer
298     criterion = torch.nn.CrossEntropyLoss().cuda()
300     model_module = model.module if hasattr(model, 'module') else model
301     if args.lr_clips is not None:
302         learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
303         clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
304         clips_params = [p for n,p in model_module.named_parameters() if 'clips' in n]
305         other_params = [p for n,p in model_module.named_parameters() if 'clips' not in n]
306         param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
307                         {'params': other_params, 'weight_decay': args.weight_decay}]
308     else:
309         param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
310     #
312     print("=> args: ", args)          
313     print("=> optimizer type   : {}".format(args.optimizer))
314     print("=> learning rate    : {}".format(args.lr))
315     print("=> resize resolution: {}".format(args.img_resize))
316     print("=> crop resolution  : {}".format(args.img_crop))
317     print("=> batch size       : {}".format(args.batch_size))
318     print("=> total batch size : {}".format(args.total_batch_size))
319     print("=> epoch size       : {}".format(args.epoch_size))
320     print("=> data augument    : {}".format(args.data_augument))
321     print("=> epochs           : {}".format(args.epochs))
322     if args.scheduler == 'step':
323         print("=> milestones       : {}".format(args.milestones))
325     learning_rate = args.lr if ('training'in args.phase) else 0.0
326     if args.optimizer == 'adam':
327         optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
328     elif args.optimizer == 'sgd':
329         optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
330     elif args.optimizer == 'sgd_nesterov':
331         optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum, nesterov=True)
332     elif args.optimizer == 'rmsprop':
333         optimizer = torch.optim.RMSprop(param_groups, learning_rate, momentum=args.momentum)
334     else:
335         raise ValueError('Unknown optimizer type{}'.format(args.optimizer))
336         
337     # optionally resume from a checkpoint
338     if args.resume:
339         if os.path.isfile(args.resume):
340             print("=> resuming from checkpoint '{}'".format(args.resume))
341             checkpoint = torch.load(args.resume)
342             if args.start_epoch == 0:
343                 args.start_epoch = checkpoint['epoch'] + 1
344                 
345             args.best_prec1 = checkpoint['best_prec1']
346             model = xnn.utils.load_weights(model, checkpoint)
347             optimizer.load_state_dict(checkpoint['optimizer'])
348             print("=> resuming from checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
349         else:
350             print("=> no checkpoint found at '{}'".format(args.resume))
352     train_loader, val_loader = get_data_loaders(args)
354     args.cur_lr = adjust_learning_rate(args, optimizer, args.start_epoch)
356     if args.evaluate_start or args.phase=='validation':
357         validate(args, val_loader, model, criterion, args.start_epoch)
359     if args.phase == 'validation':
360         close(args)
361         return
363     for epoch in range(args.start_epoch, args.stop_epoch):
364         if args.distributed:
365             train_loader.sampler.set_epoch(epoch)
367         # train for one epoch
368         train(args, train_loader, model, criterion, optimizer, epoch)
370         # evaluate on validation set
371         prec1 = validate(args, val_loader, model, criterion, epoch)
373         # remember best prec@1 and save checkpoint
374         is_best = prec1 > args.best_prec1
375         args.best_prec1 = max(prec1, args.best_prec1)
377         model_orig = get_model_orig(model)
379         save_dict = {'epoch': epoch, 'arch': args.model_name, 'state_dict': model_orig.state_dict(),
380                      'best_prec1': args.best_prec1, 'optimizer' : optimizer.state_dict(),
381                      'quantize' : args.quantize}
383         save_checkpoint(args, save_path, model_orig, save_dict, is_best)
384     #
386     # for n, m in model.named_modules():
387     #     if hasattr(m, 'num_batches_tracked'):
388     #         print(f'name={n}, num_batches_tracked={m.num_batches_tracked}')
389     # #
391     # close and cleanup
392     close(args)
395 ###################################################################
396 def is_valid_phase(phase):
397     phases = ('training', 'calibration', 'validation')
398     return any(p in phase for p in phases)
401 def close(args):
402     if args.logger is not None:
403         del args.logger
404         args.logger = None
405     #
406     args.best_prec1 = -1
409 def get_save_path(args, phase=None):
410     date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
411     save_path_base = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
412     save_path = save_path_base + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
413     phase = phase if (phase is not None) else args.phase
414     save_path = os.path.join(save_path, phase)
415     return save_path
418 def get_model_orig(model):
419     is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
420     model_orig = (model.module if is_parallel_model else model)
421     model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
422     return model_orig
425 def create_rand_inputs(args, is_cuda):
426     dummy_input = torch.rand((1, args.model_config.input_channels, args.img_crop*args.model_config.num_tiles_y,
427       args.img_crop*args.model_config.num_tiles_x))
428     dummy_input = dummy_input.cuda() if is_cuda else dummy_input
429     return dummy_input
432 def count_flops(args, model):
433     is_cuda = next(model.parameters()).is_cuda
434     dummy_input = create_rand_inputs(args, is_cuda)
435     #
436     model.eval()
437     flops = xnn.utils.forward_count_flops(model, dummy_input)
438     gflops = flops/1e9
439     print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
442 def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
443     is_cuda = next(model.parameters()).is_cuda
444     dummy_input = create_rand_inputs(args, is_cuda)
445     #
446     model.eval()
447     torch.onnx.export(model, dummy_input, os.path.join(save_path,name), export_params=True, verbose=False,
448                       do_constant_folding=True, opset_version=args.opset_version)
449     
450     #to see tensor shape in ONNX graph. Works only upto ver 8
451     if args.opset_version <= 8:
452         path = os.path.join(save_path,name)                       
453         onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
456 def train(args, train_loader, model, criterion, optimizer, epoch):
457     # actual training code
458     batch_time = AverageMeter()
459     data_time = AverageMeter()
460     losses = AverageMeter()
461     top1 = AverageMeter()
462     top5 = AverageMeter()
464     # switch to train mode
465     model.train()
466     if args.freeze_bn:
467         xnn.utils.freeze_bn(model)
468     #
470     num_iters = len(train_loader)
471     progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
472     args.cur_lr = adjust_learning_rate(args, optimizer, epoch)
474     end = time.time()
475     last_update_iter = -1
477     progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
478     print('{}'.format(progressbar_color), end='')
480     for iteration, (input, target) in enumerate(train_loader):
481         input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
482         input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
483         target = target.cuda(non_blocking=True)
485         data_time.update(time.time() - end)
487         # preprocess to make tiles
488         if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
489             input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
490         #
492         # compute output
493         output = model(input)
495         if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
496             # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class]
497             # e.g. [1,10,4,5] to [1,4,5,10]
498             output = output.permute(0, 2, 3, 1)
499             #change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
500             output = torch.reshape(output, (-1, output.shape[-1]))
501         #
503         # compute loss
504         loss = criterion(output, target) / args.iter_size
506         # measure accuracy and record loss
507         prec1, prec5 = accuracy(output, target, topk=(1, 5))
508         losses.update(loss.item(), input_size[0])
509         top1.update(prec1[0], input_size[0])
510         top5.update(prec5[0], input_size[0])
512         if 'training' in args.phase:
513             # zero gradients so that we can accumulate gradients
514             if (iteration % args.iter_size) == 0:
515                 optimizer.zero_grad()
517             loss.backward()
519             if ((iteration+1) % args.iter_size) == 0:
520                 optimizer.step()
521         #
523         # measure elapsed time
524         batch_time.update(time.time() - end)
525         end = time.time()
526         final_iter = (iteration >= (num_iters-1))
528         if ((iteration % args.print_freq) == 0) or final_iter:
529             epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
530             status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} DataTime={data_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
531                          .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)
533             progress_bar.set_description(f'=> {args.phase}  ')
534             progress_bar.set_postfix(Epoch='{}'.format(status_str))
535             progress_bar.update(iteration-last_update_iter)
536             last_update_iter = iteration
537         #
538     #
539     progress_bar.close()
541     # to print a new line - do not provide end=''
542     print('{}'.format(Fore.RESET), end='')
544     ##########################
545     if args.quantize:
546         def debug_format(v):
547             return ('{:.3f}'.format(v) if v is not None else 'None')
548         #
549         clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
550         if len(clips_act) > 0:
551             args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
552             args.logger.debug('')
553     #
557 def validate(args, val_loader, model, criterion, epoch):
558     batch_time = AverageMeter()
559     losses = AverageMeter()
560     top1 = AverageMeter()
561     top5 = AverageMeter()
563     # switch to evaluate mode
564     model.eval()
566     num_iters = len(val_loader)
567     progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
568     last_update_iter = -1
570     # change color to green
571     print('{}'.format(Fore.GREEN), end='')
573     with torch.no_grad():
574         end = time.time()
575         for iteration, (input, target) in enumerate(val_loader):
576             input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
577             input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
578             target = target.cuda(non_blocking=True)
580             # preprocess to make tiles
581             if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
582                 input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
583             #
585             # compute output
586             output = model(input)
588             if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
589                 # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class] 
590                 # e.g. [1,10,4,5] to [1,4,5,10]
591                 output = output.permute(0,2,3,1)
592                 #change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
593                 output = torch.reshape(output, (-1, output.shape[-1]))
594             #
596             loss = criterion(output, target)
598             # measure accuracy and record loss
599             prec1, prec5 = accuracy(output, target, topk=(1, 5))
600             losses.update(loss.item(), input_size[0])
601             top1.update(prec1[0], input_size[0])
602             top5.update(prec5[0], input_size[0])
604             # measure elapsed time
605             batch_time.update(time.time() - end)
606             end = time.time()
607             final_iter = (iteration >= (num_iters-1))
609             if ((iteration % args.print_freq) == 0) or final_iter:
610                 epoch_str = '{}/{}'.format(epoch+1,args.epochs)
611                 status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
612                              .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, loss=losses, top1=top1, top5=top5)
613                            
614                 prefix = '**' if final_iter else '=>'
615                 progress_bar.set_description('{} {}'.format(prefix, 'validation'))
616                 progress_bar.set_postfix(Epoch='{}'.format(status_str))
617                 progress_bar.update(iteration - last_update_iter)
618                 last_update_iter = iteration
619             #
620         #
622         progress_bar.close()
624         # to print a new line - do not provide end=''
625         print('{}'.format(Fore.RESET), end='')
627     return top1.avg
630 def save_checkpoint(args, save_path, model, state, is_best, filename='checkpoint.pth'):
631     filename = os.path.join(save_path, filename)
632     torch.save(state, filename)
633     if is_best:
634         bestname = os.path.join(save_path, 'model_best.pth')
635         shutil.copyfile(filename, bestname)
636     #
637     if args.save_onnx:
638         write_onnx_model(args, model, save_path, name='checkpoint.onnx')
639         if is_best:
640             write_onnx_model(args, model, save_path, name='model_best.onnx')
644 class AverageMeter(object):
645     """Computes and stores the average and current value"""
646     def __init__(self):
647         self.reset()
649     def reset(self):
650         self.val = 0
651         self.avg = 0
652         self.sum = 0
653         self.count = 0
655     def update(self, val, n=1):
656         self.val = val
657         self.sum += val * n
658         self.count += n
659         self.avg = self.sum / self.count
662 def adjust_learning_rate(args, optimizer, epoch):
663     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
664     cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr
666     if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
667         cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
668     elif args.scheduler == 'poly':
669         epoch_frac = (args.epochs - epoch) / args.epochs
670         epoch_frac = max(epoch_frac, 0)
671         cur_lr = args.lr * (epoch_frac ** args.polystep_power)
672         for param_group in optimizer.param_groups:
673             param_group['lr'] = cur_lr
674         #
675     elif args.scheduler == 'step':                                            # step
676         num_milestones = 0
677         for m in args.milestones:
678             num_milestones += (1 if epoch >= m else 0)
679         #
680         cur_lr = args.lr * (args.multistep_gamma ** num_milestones)
681     elif args.scheduler == 'exponential':                                   # exponential
682         cur_lr = args.lr * (args.multistep_gamma ** (epoch//args.step_size))
683     elif args.scheduler == 'cosine':                                        # cosine
684         if epoch == 0:
685             cur_lr = args.lr
686         else:
687             lr_min = 0
688             cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0  + lr_min
689         #
690     else:
691         ValueError('Unknown scheduler {}'.format(args.scheduler))
692     #
693     for param_group in optimizer.param_groups:
694         param_group['lr'] = cur_lr
695     #
696     return cur_lr
699 def accuracy(output, target, topk=(1,)):
700     """Computes the precision@k for the specified values of k"""
701     with torch.no_grad():
702         maxk = max(topk)
703         batch_size = target.size(0)
705         _, pred = output.topk(maxk, 1, True, True)
706         pred = pred.t()
707         correct = pred.eq(target.view(1, -1).expand_as(pred))
709         res = []
710         for k in topk:
711             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
712             res.append(correct_k.mul_(100.0 / batch_size))
713         return res
716 def get_dataset_sampler(dataset_object, epoch_size, balanced_sampler=False):
717     num_samples = len(dataset_object)
718     epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
719     print('=> creating a random sampler as epoch_size is specified')
720     if balanced_sampler:
721         # going through the dataset this way may take too much time
722         progress_bar = progiter.ProgIter(np.arange(num_samples), chunksize=1, \
723             desc='=> reading data to create a balanced data sampler : ')
724         sample_classes = [target for _, target in progress_bar(dataset_object)]
725         num_classes = max(sample_classes) + 1
726         sample_counts = np.zeros(num_classes, dtype=np.int32)
727         for target in sample_classes:
728             sample_counts[target] += 1
729         #
730         train_class_weights = [float(num_samples) / float(cnt) for cnt in sample_counts]
731         train_sample_weights = [train_class_weights[target] for target in sample_classes]
732         dataset_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weights, epoch_size)
733     else:
734         dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
735     #
736     return dataset_sampler
737     
739 def get_train_transform(args):
740     normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
741         if (args.image_mean is not None and args.image_scale is not None) else None
742     multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
744     train_resize_crop_transform = vision.transforms.RandomResizedCrop(size=args.img_crop, scale=args.rand_scale) \
745         if args.rand_scale else vision.transforms.RandomCrop(size=args.img_crop)
746     train_transform = vision.transforms.Compose([train_resize_crop_transform,
747                                                  vision.transforms.RandomHorizontalFlip(),
748                                                  multi_color_transform,
749                                                  vision.transforms.ToFloat(),
750                                                  vision.transforms.ToTensor(),
751                                                  normalize])
752     return train_transform
754 def get_validation_transform(args):
755     normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
756         if (args.image_mean is not None and args.image_scale is not None) else None
757     multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
759     # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
760     val_resize_crop_transform = vision.transforms.Resize(size=args.img_resize) if args.img_resize else vision.transforms.Bypass()
761     val_transform = vision.transforms.Compose([val_resize_crop_transform,
762                                                vision.transforms.CenterCrop(size=args.img_crop),
763                                                multi_color_transform,
764                                                vision.transforms.ToFloat(),
765                                                vision.transforms.ToTensor(),
766                                                normalize])
767     return val_transform
769 def get_transforms(args):
770     # Provision to train with val transform - provide rand_scale as (0, 0)
771     # Fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423
772     always_use_val_transform = (args.rand_scale[0] == 0)
773     train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
774     val_transform = get_validation_transform(args)
775     return train_transform, val_transform
777 def get_data_loaders(args):
778     train_transform, val_transform = get_transforms(args)
780     train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))
782     if args.distributed:
783         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
784         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
785     else:
786         train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
787         val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
788     #
790     train_shuffle = args.shuffle and (train_sampler is None)
791     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers,
792                                                pin_memory=True, sampler=train_sampler)
794     val_shuffle = args.shuffle_val and (val_sampler is None)
795     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=val_shuffle, num_workers=args.workers,
796                                              pin_memory=True, drop_last=False, sampler=val_sampler)
798     return train_loader, val_loader
801 if __name__ == '__main__':
802     main()