quantization_example - RandomSampler is used when epoch_size!=0. epoch_size=0.1 means...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / write_onnx_model_example.py
import os
import torch
import datetime
import torchvision as vision
# from pytorch_jacinto_ai import vision

# Export pretrained torchvision classification models in three formats,
# one set of files per entry in model_names:
#   <model>_model.pth       - TorchScript module produced by torch.jit.trace
#   <model>_state_dict.pth  - the model's state_dict, saved with torch.save
#   <model>_opset<N>.onnx   - ONNX graph written by torch.onnx.export
#
# dependencies
# Anaconda Python 3.7 for Linux - download and install from: https://www.anaconda.com/distribution/
# pytorch, torchvision - install using:
# conda install pytorch torchvision -c pytorch

# some parameters - modify as required
date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
dataset_name = 'image_folder_classification'
model_names = ['mobilenet_v2', 'resnet18', 'resnet50', 'resnext50_32x4d', 'shufflenet_v2_x1_0']
# img_resize is only used to tag the output folder name below;
# rand_crop tags the folder name AND sets the spatial size of the dummy input.
img_resize = (256, 256)
rand_crop = (224, 224)
opset_version = 9

# the saving path - you can choose any path
save_path = './data/checkpoints'
save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name)
# the folder tag writes each size as <index 1>x<index 0>
save_path += f'_resize{img_resize[1]}x{img_resize[0]}_traincrop{rand_crop[1]}x{rand_crop[0]}'
os.makedirs(save_path, exist_ok=True)

# create a random input: a batch of one 3-channel image of size rand_crop[0] x rand_crop[1]
rand_input = torch.rand(1, 3, rand_crop[0], rand_crop[1])

for model_name in model_names:
    # create the model - replace with your model
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
    # `weights=`; it is kept here so the script keeps working on older torchvision.
    model = getattr(vision.models, model_name)(pretrained=True)
    model.eval()

    # one distinct path per output format
    script_path = os.path.join(save_path, f'{model_name}_model.pth')
    state_dict_path = os.path.join(save_path, f'{model_name}_state_dict.pth')
    onnx_path = os.path.join(save_path, f'{model_name}_opset{opset_version}.onnx')

    # export is inference-only: there is no need to build an autograd graph
    # for the dummy forward passes run by tracing and by the ONNX exporter
    with torch.no_grad():
        # write pytorch model (traced TorchScript module)
        traced_model = torch.jit.trace(model, rand_input)
        torch.jit.save(traced_model, script_path)

        # write pytorch state dict
        torch.save(model.state_dict(), state_dict_path)

        # write the onnx model
        torch.onnx.export(model, rand_input, onnx_path, export_params=True, verbose=False,
                          do_constant_folding=True, opset_version=opset_version)