diff --git a/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py b/egs/wsj/s5/steps/libs/nnet3/report/log_parse.py
index 88a77d4d2d01d3c3709dd5156723c756e4a98684..1341ae2e936c75fedaf9b4fd792a46ee4a108de2 100755 (executable)
# Apache 2.0.
from __future__ import division
+from __future__ import print_function
+import traceback
import datetime
import logging
import re
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# Regex for one grep'd progress-log line of an LstmNonlinearityComponent.
# Captures the iteration number, component name and component type, then a
# (percentiles, mean, stddev) triple for both value-avg and deriv-avg of each
# of the five gates i_t, f_t, c_t, o_t, m_t: 3 + 5*6 = 33 capture groups.
# Raw strings are used so regex escapes such as \[ \. \- are not treated as
# (invalid) Python string escapes.
g_lstmp_nonlin_regex_pattern = ''.join(
    [r".*progress.([0-9]+).log:component name=(.+) ",
     r"type=(.*)Component,.*",
     r"i_t_sigmoid.*",
     r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"f_t_sigmoid.*",
     r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"c_t_tanh.*",
     r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"o_t_sigmoid.*",
     r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"m_t_tanh.*",
     r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
     r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"])
+
+
+g_normal_nonlin_regex_pattern = ''.join([".*progress.([0-9]+).log:component name=(.+) ",
+ "type=(.*)Component,.*",
+ "value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*",
+ "deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]"])
+
class KaldiLogParseException(Exception):
    """Raised when the log files cannot be parsed.

    Subclass this if more fine-grained error reporting is ever needed.
    """

    def __init__(self, message=None):
        # A whitespace-only message carries no information; treat it
        # exactly like no message at all.
        if message is not None and not message.strip():
            message = None

        details = ("There was an error while trying to parse the logs."
                   " Details : \n{0}\n".format(message))
        super(KaldiLogParseException, self).__init__(details)
+
def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table):
    """Fills one entry of `stats_table` from the capture groups of a
    nonlinearity regex match (g_normal_nonlin_regex_pattern or
    g_lstmp_nonlin_regex_pattern).

    Args:
        groups: tuple of regex capture groups; groups[0:3] are
            (iteration, component name, component type), followed by one or
            more 6-tuples of (value percentiles, value mean, value stddev,
            deriv percentiles, deriv mean, deriv stddev).
        gate_index: which 6-tuple to read -- 0 for ordinary components,
            0..4 for the i/f/c/o/m gates of an LstmNonlinearityComponent.
        stats_table: dict mapping component name ->
            {'type': ..., 'stats': {iteration: [stats, ...]}}, updated in
            place.  Stats for successive gates of the same iteration are
            appended to the same list.
    """
    iteration = int(groups[0])
    component_name = groups[1]
    component_type = groups[2]
    base = 3 + gate_index * 6

    value_percentiles = groups[base]
    value_mean = float(groups[base + 1])
    value_stddev = float(groups[base + 2])
    # The percentile string splits on ',' and ' ' into exactly 13 fields,
    # of which fields 4, 6 and 9 are the 5th, 50th and 95th percentiles.
    value_percentiles_split = re.split(',| ', value_percentiles)
    assert len(value_percentiles_split) == 13
    value_5th = float(value_percentiles_split[4])
    value_50th = float(value_percentiles_split[6])
    value_95th = float(value_percentiles_split[9])

    deriv_percentiles = groups[base + 3]
    deriv_mean = float(groups[base + 4])
    deriv_stddev = float(groups[base + 5])
    deriv_percentiles_split = re.split(',| ', deriv_percentiles)
    assert len(deriv_percentiles_split) == 13
    deriv_5th = float(deriv_percentiles_split[4])
    deriv_50th = float(deriv_percentiles_split[6])
    deriv_95th = float(deriv_percentiles_split[9])

    row = [value_mean, value_stddev,
           deriv_mean, deriv_stddev,
           value_5th, value_50th, value_95th,
           deriv_5th, deriv_50th, deriv_95th]

    if component_name not in stats_table:
        stats_table[component_name] = {'type': component_type, 'stats': {}}
    stats = stats_table[component_name]['stats']
    # BUGFIX: the original used dict.has_key(), which was removed in
    # Python 3; worse, the resulting AttributeError was not caught by the
    # surrounding 'except KeyError', so parsing crashed on py3.
    if iteration in stats:
        stats[iteration].extend(row)
    else:
        stats[iteration] = row
+
def parse_progress_logs_for_nonlinearity_stats(exp_dir):
- """ Parse progress logs for mean and std stats for non-linearities.
+ """ Parse progress logs for mean and std stats for non-linearities.
e.g. for a line that is parsed from progress.*.log:
exp/nnet3/lstm_self_repair_ld5_sp/log/progress.9.log:component name=Lstm3_i
type=SigmoidComponent, dim=1280, self-repair-scale=1e-05, count=1.96e+05,
progress_log_files = "%s/log/progress.*.log" % (exp_dir)
stats_per_component_per_iter = {}
- progress_log_lines = common_lib.run_kaldi_command(
- 'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files))[0]
+ progress_log_lines = common_lib.get_command_stdout(
+ 'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files),
+ require_zero_status = False)
+
+ parse_regex = re.compile(g_normal_nonlin_regex_pattern)
- parse_regex = re.compile(
- ".*progress.([0-9]+).log:component name=(.+) "
- "type=(.*)Component,.*"
- "value-avg=\[.*mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*"
- "deriv-avg=\[.*mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]")
for line in progress_log_lines.split("\n"):
mat_obj = parse_regex.search(line)
if mat_obj is None:
continue
- # groups = ('9', 'Lstm3_i', 'Sigmoid', '0.502', '0.23',
- # '0.134', '0.0397')
+ # groups = ('9', 'Lstm3_i', 'Sigmoid', '0.05...0.99', '0.502', '0.23',
+ # '0.009...0.21', '0.134', '0.0397')
groups = mat_obj.groups()
- iteration = int(groups[0])
- component_name = groups[1]
component_type = groups[2]
- value_mean = float(groups[3])
- value_stddev = float(groups[4])
- deriv_mean = float(groups[5])
- deriv_stddev = float(groups[6])
- try:
- stats_per_component_per_iter[component_name][
- 'stats'][iteration] = [value_mean, value_stddev,
- deriv_mean, deriv_stddev]
- except KeyError:
- stats_per_component_per_iter[component_name] = {}
- stats_per_component_per_iter[component_name][
- 'type'] = component_type
- stats_per_component_per_iter[component_name]['stats'] = {}
- stats_per_component_per_iter[component_name][
- 'stats'][iteration] = [value_mean, value_stddev,
- deriv_mean, deriv_stddev]
-
+ if component_type == 'LstmNonlinearity':
+ parse_regex_lstmp = re.compile(g_lstmp_nonlin_regex_pattern)
+ mat_obj = parse_regex_lstmp.search(line)
+ groups = mat_obj.groups()
+ assert len(groups) == 33
+ for i in list(range(0,5)):
+ fill_nonlin_stats_table_with_regex_result(groups, i,
+ stats_per_component_per_iter)
+ else:
+ fill_nonlin_stats_table_with_regex_result(groups, 0,
+ stats_per_component_per_iter)
return stats_per_component_per_iter
progress_log_files = "%s/log/progress.*.log" % (exp_dir)
component_names = set([])
- progress_log_lines = common_lib.run_kaldi_command(
+ progress_log_lines = common_lib.get_command_stdout(
'grep -e "{0}" {1}'.format(
- "clipped-proportion", progress_log_files))[0]
+ "clipped-proportion", progress_log_files),
+ require_zero_status=False)
parse_regex = re.compile(".*progress\.([0-9]+)\.log:component "
"name=(.*) type=.* "
"clipped-proportion=([0-9\.e\-]+)")
progress_log_files = "%s/log/progress.*.log" % (exp_dir)
progress_per_iter = {}
component_names = set([])
- progress_log_lines = common_lib.run_kaldi_command(
- 'grep -e "{0}" {1}'.format(pattern, progress_log_files))[0]
+ progress_log_lines = common_lib.get_command_stdout(
+ 'grep -e "{0}" {1}'.format(pattern, progress_log_files))
parse_regex = re.compile(".*progress\.([0-9]+)\.log:"
"LOG.*{0}.*\[(.*)\]".format(pattern))
for line in progress_log_lines.split("\n"):
'max_iter': max_iter}
-def parse_train_logs(exp_dir):
- train_log_files = "%s/log/train.*.log" % (exp_dir)
- train_log_lines = common_lib.run_kaldi_command(
- 'grep -e Accounting {0}'.format(train_log_files))[0]
+def get_train_times(exp_dir):
+ train_log_files = "%s/log/" % (exp_dir)
+ train_log_names = "train.*.log"
+ train_log_lines = common_lib.get_command_stdout(
+ 'find {0} -name "{1}" | xargs grep -H -e Accounting'.format(train_log_files,train_log_names))
parse_regex = re.compile(".*train\.([0-9]+)\.([0-9]+)\.log:# "
"Accounting: time=([0-9]+) thread.*")
def parse_prob_logs(exp_dir, key='accuracy', output="output"):
train_prob_files = "%s/log/compute_prob_train.*.log" % (exp_dir)
valid_prob_files = "%s/log/compute_prob_valid.*.log" % (exp_dir)
- train_prob_strings = common_lib.run_kaldi_command(
- 'grep -e {0} {1}'.format(key, train_prob_files), wait=True)[0]
- valid_prob_strings = common_lib.run_kaldi_command(
- 'grep -e {0} {1}'.format(key, valid_prob_files))[0]
+ train_prob_strings = common_lib.get_command_stdout(
+ 'grep -e {0} {1}'.format(key, train_prob_files))
+ valid_prob_strings = common_lib.get_command_stdout(
+ 'grep -e {0} {1}'.format(key, valid_prob_files))
# LOG
# (nnet3-chain-compute-prob:PrintTotalStats():nnet-chain-diagnostics.cc:149)
parse_regex = re.compile(
".*compute_prob_.*\.([0-9]+).log:LOG "
- ".nnet3.*compute-prob:PrintTotalStats..:"
+ ".nnet3.*compute-prob.*:PrintTotalStats..:"
"nnet.*diagnostics.cc:[0-9]+. Overall ([a-zA-Z\-]+) for "
"'{output}'.*is ([0-9.\-e]+) .*per frame".format(output=output))
groups = mat_obj.groups()
if groups[1] == key:
train_loss[int(groups[0])] = groups[2]
+ if not train_loss:
+ raise KaldiLogParseException("Could not find any lines with {k} in "
+ " {l}".format(k=key, l=train_prob_files))
+
for line in valid_prob_strings.split('\n'):
mat_obj = parse_regex.search(line)
if mat_obj is not None:
groups = mat_obj.groups()
if groups[1] == key:
valid_loss[int(groups[0])] = groups[2]
+
+ if not valid_loss:
+ raise KaldiLogParseException("Could not find any lines with {k} in "
+ " {l}".format(k=key, l=valid_prob_files))
+
iters = list(set(valid_loss.keys()).intersection(train_loss.keys()))
+ if not iters:
+ raise KaldiLogParseException("Could not any common iterations with"
+ " key {k} in both {tl} and {vl}".format(
+ k=key, tl=train_prob_files, vl=valid_prob_files))
iters.sort()
return map(lambda x: (int(x), float(train_loss[x]),
float(valid_loss[x])), iters)
-def generate_accuracy_report(exp_dir, key="accuracy", output="output"):
- times = parse_train_logs(exp_dir)
- data = parse_prob_logs(exp_dir, key, output)
+
+def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
+ try:
+ times = get_train_times(exp_dir)
+ except:
+ tb = traceback.format_exc()
+ logger.warning("Error getting info from logs, exception was: " + tb)
+ times = []
+
report = []
report.append("%Iter\tduration\ttrain_loss\tvalid_loss\tdifference")
+ try:
+ data = list(parse_prob_logs(exp_dir, key, output))
+ except:
+ tb = traceback.format_exc()
+ logger.warning("Error getting info from logs, exception was: " + tb)
+ data = []
for x in data:
try:
report.append("%d\t%s\t%g\t%g\t%g" % (x[0], str(times[x[0]]),