]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - examples/write_onnx_model_example.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / write_onnx_model_example.py
1 # Copyright (c) 2018-2021, Texas Instruments
2 # All Rights Reserved.
3 #
4 # Redistribution and use in source and binary forms, with or without
5 # modification, are permitted provided that the following conditions are met:
6 #
7 # * Redistributions of source code must retain the above copyright notice, this
8 #   list of conditions and the following disclaimer.
9 #
10 # * Redistributions in binary form must reproduce the above copyright notice,
11 #   this list of conditions and the following disclaimer in the documentation
12 #   and/or other materials provided with the distribution.
13 #
14 # * Neither the name of the copyright holder nor the names of its
15 #   contributors may be used to endorse or promote products derived from
16 #   this software without specific prior written permission.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 import os
30 import torch
31 import datetime
32 import torchvision as xvision
33 # from pytorch_jacinto_ai import xvision
35 # dependencies
36 # Python 3.7 (might work in other versions as well)
37 # pytorch, torchvision - install using: 
38 # conda install pytorch torchvision -c pytorch
40 # some parameters - modify as required
41 date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
42 dataset_name = 'image_folder_classification'
43 model_names = ['mobilenet_v2', 'resnet18', 'resnet50', 'resnext50_32x4d', 'shufflenet_v2_x1_0']
44 img_resize = (256,256)
45 rand_crop = (224,224)
46 opset_version = 9
48 # the saving path - you can choose any path
49 save_path = './data/checkpoints'
50 save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name)
51 save_path += '_resize{}x{}_traincrop{}x{}'.format(img_resize[1], img_resize[0], rand_crop[1], rand_crop[0])
52 os.makedirs(save_path, exist_ok=True)
54 # create a rand input
55 rand_input = torch.rand(1, 3, rand_crop[0], rand_crop[1])
57 for model_name in model_names:
58     # create the model - replace with your model
59     model = xvision.models.__dict__[model_name](pretrained=True)
60     model.eval()
62     # write pytorch model
63     model_path=os.path.join(save_path, f'{model_name}_model.pth')
64     traced_model = torch.jit.trace(model, rand_input)
65     torch.jit.save(traced_model, model_path)
67     # write pytorch sate dict
68     model_path=os.path.join(save_path, f'{model_name}_state_dict.pth')
69     torch.save(model.state_dict(), model_path)
71     # write the onnx model
72     model_path=os.path.join(save_path, f'{model_name}_opset{opset_version}.onnx')
73     torch.onnx.export(model, rand_input, model_path, export_params=True, verbose=False,
74                       do_constant_folding=True, opset_version=opset_version)