import os
import torch
import datetime
import torchvision as vision
# from pytorch_jacinto_ai import vision

# Export a list of pretrained torchvision classification models in three formats:
#   1. a traced TorchScript model   (<name>_model.pth)
#   2. a plain state dict           (<name>_state_dict.pth)
#   3. an ONNX model                (<name>_opset<N>.onnx)
#
# Dependencies:
#   Anaconda Python 3.7 for Linux - download and install from: https://www.anaconda.com/distribution/
#   pytorch, torchvision - install using:
#     conda install pytorch torchvision -c pytorch

# Some parameters - modify as required.
date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
dataset_name = 'image_folder_classification'
model_names = ['mobilenet_v2', 'resnet18', 'resnet50', 'resnext50_32x4d', 'shufflenet_v2_x1_0']
img_resize = (256, 256)
rand_crop = (224, 224)
# ONNX operator-set version used for every exported model (loop-invariant, so defined once here).
opset_version = 9

# The saving path - you can choose any path.
# Each run gets its own timestamped directory so earlier exports are not overwritten.
save_path = './data/checkpoints'
save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name)
# Sizes are written index-1 first, then index-0 (presumably width x height - TODO confirm).
save_path += '_resize{}x{}_traincrop{}x{}'.format(img_resize[1], img_resize[0], rand_crop[1], rand_crop[0])
os.makedirs(save_path, exist_ok=True)

# Create a random example input of shape (1, 3, rand_crop[0], rand_crop[1]).
# It is used only to trace / export the models, so its values do not matter.
rand_input = torch.rand(1, 3, rand_crop[0], rand_crop[1])

for model_name in model_names:
    # Create the model - replace with your model.
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
    # `weights=...`; kept as-is for compatibility with older torchvision versions.
    model = getattr(vision.models, model_name)(pretrained=True)
    # Eval mode so layers such as dropout / batch-norm behave deterministically while tracing.
    model.eval()

    # Exporting needs no gradients; disabling autograd avoids recording graph state.
    with torch.no_grad():
        # Write the traced TorchScript model.
        model_path = os.path.join(save_path, f'{model_name}_model.pth')
        traced_model = torch.jit.trace(model, rand_input)
        torch.jit.save(traced_model, model_path)

        # Write the PyTorch state dict.
        model_path = os.path.join(save_path, f'{model_name}_state_dict.pth')
        torch.save(model.state_dict(), model_path)

        # Write the ONNX model.
        model_path = os.path.join(save_path, f'{model_name}_opset{opset_version}.onnx')
        torch.onnx.export(model, rand_input, model_path, export_params=True, verbose=False,
                          opset_version=opset_version)