[scripts] A cosmetic change to info messages in chain training (#1880)
authorHossein Hadian <hn.hadian@gmail.com>
Mon, 11 Sep 2017 17:57:29 +0000 (12:57 -0500)
committerDaniel Povey <dpovey@gmail.com>
Mon, 11 Sep 2017 17:57:29 +0000 (13:57 -0400)
egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py
egs/wsj/s5/steps/nnet3/chain/train.py

index fedce12dda0736334362fcf00fdba2c74e130a7f..52d97d9a0be6f4786dba2dab4047bd7123957942 100644 (file)
@@ -240,8 +240,6 @@ def train_one_iteration(dir, iter, srand, egs_dir,
 
     # Set off jobs doing some diagnostics, in the background.
     # Use the egs dir from the previous iteration for the diagnostics
-    logger.info("Training neural net (pass {0})".format(iter))
-
     # check if different iterations use the same random seed
     if os.path.exists('{0}/srand'.format(dir)):
         try:
@@ -290,16 +288,6 @@ def train_one_iteration(dir, iter, srand, egs_dir,
         cur_max_param_change = float(max_param_change) / math.sqrt(2)
 
     raw_model_string = raw_model_string + dropout_edit_string
-
-    shrink_info_str = ''
-    if shrinkage_value != 1.0:
-        shrink_info_str = ' and shrink value is {0}'.format(shrinkage_value)
-
-    logger.info("On iteration {0}, learning rate is {1}"
-                "{shrink_info}.".format(
-                    iter, learning_rate,
-                    shrink_info=shrink_info_str))
-
     train_new_models(dir=dir, iter=iter, srand=srand, num_jobs=num_jobs,
                      num_archives_processed=num_archives_processed,
                      num_archives=num_archives,
index 6f9452c457c4cb652722d5588c431077c26ee8fe..55c0c25dfd55b2f83e82e9d27144fdf6a044e4ac 100755 (executable)
@@ -450,6 +450,19 @@ def train(args, run_opts):
                                         args.shrink_saturation_threshold)
                                    else shrinkage_value)
 
+            percent = num_archives_processed * 100.0 / num_archives_to_process
+            epoch = (num_archives_processed * args.num_epochs
+                     / num_archives_to_process)
+            shrink_info_str = ''
+            if shrinkage_value != 1.0:
+                shrink_info_str = 'shrink: {0:0.5f}'.format(shrinkage_value)
+            logger.info("Iter: {0}/{1}    "
+                        "Epoch: {2:0.2f}/{3:0.1f} ({4:0.1f}% complete)    "
+                        "lr: {5:0.6f}    {6}".format(iter, num_iters - 1,
+                                                     epoch, args.num_epochs,
+                                                     percent,
+                                                     lrate, shrink_info_str))
+
             chain_lib.train_one_iteration(
                 dir=args.dir,
                 iter=iter,