]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - scripts/train_depth_main.py
improved speed in training pixel2pixel models, added unet, other fixes
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / scripts / train_depth_main.py
index e5db9066de4452c241e1ef8ce62bb2cde7c81fd4..acf1ad5b5ec3446503d450fe6481a020b4b61951 100755 (executable)
@@ -46,8 +46,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 ################################