doc update - to clarify the use of model.train() and model.eval()
authorManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 03:50:30 +0000 (09:20 +0530)
committerManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:56:53 +0000 (13:26 +0530)
docs/Calibration.md
docs/Quantization.md

index 247743afb2dcc53b0d416c9ec3c44ed900d4edc7..2c97bb95680856e5fba7a820b181cf8b746eca5c 100644 (file)
@@ -56,13 +56,17 @@ model.module.load_state_dict(pretrained_data)
 my_dataset_train, my_dataset_val = ...
 
 # do one epoch of calibration - in practice about 1000 iterations are sufficient.
-for images, target in my_dataset_train:
+model.train()
+for images, targets in my_dataset_train:
     output = model(images)
+    # calibration doesn't need anything else here - not even the loss function.
+    # so the targets are also not needed.
 
 # save the model - the calibrated module is in model.module
-torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
+# calibrated model can export a clean onnx graph with clips in eval mode.
+model.eval()
 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False, do_constant_folding=True, opset_version=9)
-
+torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
 ```
 
 Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
index 5d9e0e24a8dbf20583313795e895014ceb80653a..7c6dc13e5d4f40f8c9d2a20c783ce3e28d84e53e 100644 (file)
@@ -88,13 +88,17 @@ pretrained_data = torch.load(pretrained_path)
 model.module.load_state_dict(pretrained_data)
 
 # your training loop here with with loss, backward, optimizer and scheduler. 
-# this is the usual training loop - but use a lower learning rate such as 5e-5
-....
-....
+# this is the usual training loop - but use a lower learning rate such as 1e-5
+model.train()
+for images, target in my_dataset_train:
+    output = model(images)
+    # loss, backward(), optimizer step etc comes here as usual in training
 
 # save the model - the trained module is in model.module
-torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
+# QAT model can export a clean onnx graph with clips in eval mode.
+model.eval()
 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False, do_constant_folding=True, opset_version=9)
+torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
 ```
 
 As can be seen, it is easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the parameters of the pretrained model is loaded and trained model is saved as shown in the above code snippet.