release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / write_onnx_model_example.py
index 64f2f8ed157166b59acded0b1e7a90f90be9a738..be09729ac1021d2782141fdac90e8205893edb7a 100644 (file)
@@ -1,7 +1,8 @@
 import os
 import torch
-import torchvision
 import datetime
+import torchvision as vision
+# from pytorch_jacinto_ai import vision
 
 # dependencies
 # Anaconda Python 3.7 for Linux - download and install from: https://www.anaconda.com/distribution/
@@ -11,21 +12,34 @@ import datetime
 # some parameters - modify as required
 date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
 dataset_name = 'image_folder_classification'
-model_name = 'resnet50'
+model_names = ['mobilenet_v2', 'resnet18', 'resnet50', 'resnext50_32x4d', 'shufflenet_v2_x1_0']
 img_resize = (256,256)
 rand_crop = (224,224)
 
 # the saving path - you can choose any path
 save_path = './data/checkpoints'
-save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name + '_' + model_name)
+save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name)
 save_path += '_resize{}x{}_traincrop{}x{}'.format(img_resize[1], img_resize[0], rand_crop[1], rand_crop[0])
 os.makedirs(save_path, exist_ok=True)
 
-# create the model - replace with your model
-model = torchvision.models.resnet50(pretrained=True)
-
 # create a rand input
 rand_input = torch.rand(1, 3, rand_crop[0], rand_crop[1])
 
-# write the onnx model
-torch.onnx.export(model, rand_input, os.path.join(save_path, 'model.onnx'), export_params=True, verbose=False)
+for model_name in model_names:
+    # create the model - replace with your model
+    model = vision.models.__dict__[model_name](pretrained=True)
+    model.eval()
+
+    # write pytorch model
+    model_path=os.path.join(save_path, f'{model_name}_model.pth')
+    traced_model = torch.jit.trace(model, rand_input)
+    torch.jit.save(traced_model, model_path)
+
+    # write pytorch sate dict
+    model_path=os.path.join(save_path, f'{model_name}_state_dict.pth')
+    torch.save(model.state_dict(), model_path)
+
+    # write the onnx model
+    opset_version=9
+    model_path=os.path.join(save_path, f'{model_name}_opset{opset_version}.onnx')
+    torch.onnx.export(model, rand_input, model_path, export_params=True, verbose=False, opset_version=opset_version)