[scripts] Minor fix to nnet3 training scripts RE log parsing/time-reporting (#1926)
[processor-sdk/kaldi.git] / egs / wsj / s5 / steps / libs / nnet3 / report / log_parse.py
3 # Copyright 2016    Vijayaditya Peddinti
4 #                   Vimal Manohar
5 # Apache 2.0.
7 from __future__ import division
8 from __future__ import print_function
9 import traceback
10 import datetime
11 import logging
12 import re
14 import libs.common as common_lib
# Module logger. A NullHandler is attached so that this library emits nothing
# by itself; output appears only if the calling script configures logging.
logger = logging.getLogger(__name__)
_null_handler = logging.NullHandler()
logger.addHandler(_null_handler)
# Regular expressions for the per-component nonlinearity statistics found in
# the grep output of progress.*.log (see
# parse_progress_logs_for_nonlinearity_stats()).
#
# The header yields 3 groups: (iteration, component name, component type
# without the "Component" suffix). Each value-avg/deriv-avg pair then yields
# 6 groups: (value percentiles, value mean, value stddev,
#            deriv percentiles, deriv mean, deriv stddev).
_g_nonlin_header_regex = (r".*progress\.([0-9]+)\.log:component name=(.+) "
                          r"type=(.*)Component,.*")
_g_value_deriv_regex = (
    r"value-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\].*"
    r"deriv-avg=\[.*=\((.+)\), mean=([0-9\.\-e]+), stddev=([0-9\.e\-]+)\]")

# LstmNonlinearity components report one value/deriv pair per gate, in this
# order, giving 3 + 5 * 6 = 33 groups in total.
_g_lstm_gate_names = ["i_t_sigmoid", "f_t_sigmoid", "c_t_tanh",
                      "o_t_sigmoid", "m_t_tanh"]

g_lstmp_nonlin_regex_pattern = _g_nonlin_header_regex + ".*".join(
    gate + ".*" + _g_value_deriv_regex for gate in _g_lstm_gate_names)

# Every other nonlinearity reports a single value/deriv pair: 3 + 6 = 9 groups.
g_normal_nonlin_regex_pattern = _g_nonlin_header_regex + _g_value_deriv_regex
class KaldiLogParseException(Exception):
    """ An Exception class that throws an error when there is an issue in
    parsing the log files. Extend this class if more granularity is needed.

    Args:
        message: optional details about the failure. When it is None or
            only whitespace, the exception text is just the generic summary;
            otherwise the details are appended after the summary.
    """
    def __init__(self, message=None):
        if message is not None and message.strip() == "":
            message = None

        summary = "There was an error while trying to parse the logs."
        if message is None:
            # No details to report: don't print the literal text "None".
            Exception.__init__(self, summary)
        else:
            Exception.__init__(self,
                               summary + " Details : \n{0}\n".format(message))
def _parse_5th_50th_95th_percentiles(percentiles_string):
    """Return (5th, 50th, 95th) percentiles parsed from a percentile list.

    The list corresponds to percentiles (0,1,2,5 10,20,50,80,90
    95,98,99,100); values are separated by commas or spaces, e.g.
    "0.05,0.09,0.11,0.15 0.19,0.27,0.50,0.72,0.83 0.88,0.92,0.94,0.99".
    """
    values = re.split(',| ', percentiles_string)
    assert len(values) == 13
    # Positions within (0,1,2,5,10,20,50,80,90,95,98,99,100):
    # 5th -> index 3, 50th -> index 6, 95th -> index 9.
    return float(values[3]), float(values[6]), float(values[9])


# This function is used to fill stats_per_component_per_iter table with the
# results of regular expression.
def fill_nonlin_stats_table_with_regex_result(groups, gate_index, stats_table):
    """Add one gate's nonlinearity statistics from a regex match to a table.

    Args:
        groups: match.groups() for g_normal_nonlin_regex_pattern or
            g_lstmp_nonlin_regex_pattern: (iteration, component name,
            component type) followed by 6 groups per gate (value
            percentiles, value mean, value stddev, deriv percentiles,
            deriv mean, deriv stddev).
        gate_index: which gate's group of 6 to read.
        stats_table: dict updated in place, mapping component name to
            {'type': component type, 'stats': {iteration: [stats, ...]}}.

    The stats added for the gate are
        [value_mean, value_stddev, deriv_mean, deriv_stddev,
         value_5th, value_50th, value_95th,
         deriv_5th, deriv_50th, deriv_95th]
    and calls for an already-seen (component, iteration) pair extend the
    existing list.
    """
    iteration = int(groups[0])
    component_name = groups[1]
    component_type = groups[2]

    # The 3 header groups are followed by 6 groups per gate.
    offset = 3 + gate_index * 6
    value_mean = float(groups[offset + 1])
    value_stddev = float(groups[offset + 2])
    deriv_mean = float(groups[offset + 4])
    deriv_stddev = float(groups[offset + 5])
    value_5th, value_50th, value_95th = _parse_5th_50th_95th_percentiles(
        groups[offset])
    deriv_5th, deriv_50th, deriv_95th = _parse_5th_50th_95th_percentiles(
        groups[offset + 3])

    stats = [value_mean, value_stddev,
             deriv_mean, deriv_stddev,
             value_5th, value_50th, value_95th,
             deriv_5th, deriv_50th, deriv_95th]

    component_entry = stats_table.setdefault(
        component_name, {'type': component_type, 'stats': {}})
    # LstmNonlinearity lines are processed once per gate for the same
    # iteration, so the per-gate stats accumulate in one list.
    component_entry['stats'].setdefault(iteration, []).extend(stats)
def parse_progress_logs_for_nonlinearity_stats(exp_dir):

    """ Parse progress logs for mean and std stats for non-linearities.
    e.g. for a line that is parsed from progress.*.log:
    exp/nnet3/lstm_self_repair_ld5_sp/log/progress.9.log:component name=Lstm3_i
    type=SigmoidComponent, dim=1280, self-repair-scale=1e-05, count=1.96e+05,
    value-avg=[percentiles(0,1,2,5 10,20,50,80,90
    95,98,99,100)=(0.05,0.09,0.11,0.15 0.19,0.27,0.50,0.72,0.83
    0.88,0.92,0.94,0.99), mean=0.502, stddev=0.23],
    deriv-avg=[percentiles(0,1,2,5 10,20,50,80,90
    95,98,99,100)=(0.009,0.04,0.05,0.06 0.08,0.10,0.14,0.17,0.18
    0.19,0.20,0.20,0.21), mean=0.134, stddev=0.0397]

    Returns a dict mapping component name to
    {'type': component type (without the "Component" suffix),
     'stats': {iteration: list of stats}}, as filled by
    fill_nonlin_stats_table_with_regex_result(). Lines that do not match the
    expected format are skipped.
    """

    progress_log_files = "%s/log/progress.*.log" % (exp_dir)
    stats_per_component_per_iter = {}

    progress_log_lines = common_lib.get_command_stdout(
        'grep -e "value-avg.*deriv-avg" {0}'.format(progress_log_files),
        require_zero_status=False)

    # Compile both patterns once, outside the per-line loop.
    parse_regex = re.compile(g_normal_nonlin_regex_pattern)
    parse_regex_lstmp = re.compile(g_lstmp_nonlin_regex_pattern)

    for line in progress_log_lines.split("\n"):
        mat_obj = parse_regex.search(line)
        if mat_obj is None:
            continue
        # groups = ('9', 'Lstm3_i', 'Sigmoid', '0.05...0.99', '0.502', '0.23',
        # '0.009...0.21', '0.134', '0.0397')
        groups = mat_obj.groups()
        component_type = groups[2]
        if component_type == 'LstmNonlinearity':
            # These lines carry one value/deriv block per gate (5 gates).
            lstm_mat_obj = parse_regex_lstmp.search(line)
            if lstm_mat_obj is None:
                logger.warning("Could not parse per-gate stats of an "
                               "LstmNonlinearity component, skipping line: "
                               "%s", line)
                continue
            groups = lstm_mat_obj.groups()
            assert len(groups) == 33
            for gate_index in range(5):
                fill_nonlin_stats_table_with_regex_result(
                    groups, gate_index, stats_per_component_per_iter)
        else:
            fill_nonlin_stats_table_with_regex_result(
                groups, 0, stats_per_component_per_iter)
    return stats_per_component_per_iter
def parse_difference_string(string):
    """Parse whitespace-separated "name:value" pairs into a dict.

    e.g. "Cwrnn1_T3_W_r:0.0171537 Final_affine:0.0321521" ->
    {'Cwrnn1_T3_W_r': 0.0171537, 'Final_affine': 0.0321521}.

    The value is taken after the last ':' so that names which themselves
    contain ':' are kept intact. Later duplicates of a name overwrite
    earlier ones.
    """
    differences = {}
    for part in string.split():
        name_and_value = part.rsplit(":", 1)
        differences[name_and_value[0]] = float(name_and_value[1])
    return differences
class MalformedClippedProportionLineException(Exception):
    """Raised for a clipped-proportion log line that cannot be parsed or
    whose value is out of range; the offending line is included in the
    message."""

    def __init__(self, line):
        text = ("Malformed line encountered while trying to "
                "extract clipped-proportions.\n{0}").format(line)
        super(MalformedClippedProportionLineException, self).__init__(text)
def parse_progress_logs_for_clipped_proportion(exp_dir):
    """ Parse progress logs for clipped proportion stats.

    e.g. for a line that is parsed from progress.*.log:
    exp/chain/cwrnn_trial2_ld5_sp/log/progress.245.log:component
    name=BLstm1_forward_c type=ClipGradientComponent, dim=512,
    norm-based-clipping=true, clipping-threshold=30,
    clipped-proportion=0.000565527,
    self-repair-clipped-proportion-threshold=0.01, self-repair-target=0,
    self-repair-scale=1

    Returns a dict with:
        'table': a header row ["iteration", name_1, ..., name_k] (names
            sorted) followed by one row [iteration, cp_1, ..., cp_k] per
            iteration in increasing order; None marks a component with no
            value for that iteration.
        'cp_per_component_per_iter': {iteration: {name: clipped proportion}}
        'cp_per_iter_per_component': {name: [[iteration, clipped proportion],
            ...]} in increasing iteration order.

    Raises:
        MalformedClippedProportionLineException: for a non-empty grep line
            that does not match the expected format, or whose clipped
            proportion is not a number in [0, 1].
    """

    progress_log_files = "%s/log/progress.*.log" % (exp_dir)
    progress_log_lines = common_lib.get_command_stdout(
        'grep -e "{0}" {1}'.format(
            "clipped-proportion", progress_log_files),
        require_zero_status=False)
    parse_regex = re.compile(r".*progress\.([0-9]+)\.log:component "
                             r"name=(.*) type=.* "
                             r"clipped-proportion=([0-9\.e\-]+)")

    cp_per_component_per_iter = {}
    component_names = set()
    for line in progress_log_lines.split("\n"):
        mat_obj = parse_regex.search(line)
        if mat_obj is None:
            if line.strip() == "":
                continue
            raise MalformedClippedProportionLineException(line)
        iteration = int(mat_obj.group(1))
        name = mat_obj.group(2)
        try:
            clipped_proportion = float(mat_obj.group(3))
        except ValueError:
            # The character class also admits strings such as "1-2".
            raise MalformedClippedProportionLineException(line)
        if not 0 <= clipped_proportion <= 1:
            raise MalformedClippedProportionLineException(line)
        cp_per_component_per_iter.setdefault(iteration, {})[name] = (
            clipped_proportion)
        component_names.add(name)
    component_names = sorted(component_names)

    # Re-arrange the data into a table (one row per iteration) and into
    # cp_per_iter_per_component.
    cp_per_iter_per_component = {name: [] for name in component_names}
    data = [["iteration"] + component_names]
    for iteration in sorted(cp_per_component_per_iter):
        comp_dict = cp_per_component_per_iter[iteration]
        row = [iteration]
        for component in component_names:
            if component in comp_dict:
                row.append(comp_dict[component])
                cp_per_iter_per_component[component].append(
                    [iteration, comp_dict[component]])
            else:
                # If the clipped proportion is not available for a particular
                # component it is set to None; this usually happens during
                # layer-wise discriminative training.
                row.append(None)
        data.append(row)

    return {'table': data,
            'cp_per_component_per_iter': cp_per_component_per_iter,
            'cp_per_iter_per_component': cp_per_iter_per_component}
def parse_progress_logs_for_param_diff(exp_dir, pattern):
    """ Parse progress logs for per-component parameter differences.

    e.g. for a line that is parsed from progress.*.log:
    exp/chain/cwrnn_trial2_ld5_sp/log/progress.245.log:LOG
    (nnet3-show-progress:main():nnet3-show-progress.cc:144) Relative parameter
    differences per layer are [ Cwrnn1_T3_W_r:0.0171537
    Cwrnn1_T3_W_x:1.33338e-07 Cwrnn1_T2_W_r:0.048075 Cwrnn1_T2_W_x:1.34088e-07
    Cwrnn1_T1_W_r:0.0157277 Cwrnn1_T1_W_x:0.0212704 Final_affine:0.0321521
    Cwrnn2_T3_W_r:0.0212082 Cwrnn2_T3_W_x:1.33691e-07 Cwrnn2_T2_W_r:0.0212978
    Cwrnn2_T2_W_x:1.33401e-07 Cwrnn2_T1_W_r:0.014976 Cwrnn2_T1_W_x:0.0233588
    Cwrnn3_T3_W_r:0.0237165 Cwrnn3_T3_W_x:1.33184e-07 Cwrnn3_T2_W_r:0.0239754
    Cwrnn3_T2_W_x:1.3296e-07 Cwrnn3_T1_W_r:0.0194809 Cwrnn3_T1_W_x:0.0271934 ]

    Args:
        exp_dir: directory whose log/progress.*.log files are read.
        pattern: "Relative parameter differences" or "Parameter differences".

    Returns a dict with:
        'progress_per_component': {component name: {iteration: difference}}
        'component_names': sorted list of every component name seen
        'max_iter': the largest iteration found

    Raises:
        Exception: if ``pattern`` is not one of the supported values.
        KaldiLogParseException: if no progress log line could be parsed.
    """

    if pattern not in set(["Relative parameter differences",
                           "Parameter differences"]):
        raise Exception("Unknown value for pattern : {0}".format(pattern))

    progress_log_files = "%s/log/progress.*.log" % (exp_dir)
    progress_per_iter = {}
    component_names = set()
    progress_log_lines = common_lib.get_command_stdout(
        'grep -e "{0}" {1}'.format(pattern, progress_log_files))
    parse_regex = re.compile(r".*progress\.([0-9]+)\.log:"
                             r"LOG.*{0}.*\[(.*)\]".format(re.escape(pattern)))
    for line in progress_log_lines.split("\n"):
        mat_obj = parse_regex.search(line)
        if mat_obj is None:
            continue
        differences = parse_difference_string(mat_obj.group(2))
        component_names.update(differences.keys())
        progress_per_iter[int(mat_obj.group(1))] = differences

    if not progress_per_iter:
        # Without this, max() below fails with an opaque ValueError.
        raise KaldiLogParseException(
            "Could not find any lines with '{0}' in {1}".format(
                pattern, progress_log_files))

    component_names = sorted(component_names)
    # Re-arrange the parameter differences available per iteration into
    # parameter differences per component.
    progress_per_component = {name: {} for name in component_names}

    max_iter = max(progress_per_iter)
    total_missing_iterations = 0
    gave_user_warning = False
    for iteration in sorted(progress_per_iter):
        component_dict = progress_per_iter[iteration]
        for component_name in component_names:
            if component_name in component_dict:
                progress_per_component[component_name][iteration] = (
                    component_dict[component_name])
            else:
                # The component was not found this iteration, maybe because
                # of layer-wise discriminative training.
                total_missing_iterations += 1
        # Guard against division by zero when every bracket list was empty.
        if (component_names and not gave_user_warning
                and total_missing_iterations / len(component_names) > 20):
            logger.warning("There are more than {0} missing iterations per "
                           "component. Something might be wrong.".format(
                                total_missing_iterations/len(component_names)))
            gave_user_warning = True

    return {'progress_per_component': progress_per_component,
            'component_names': component_names,
            'max_iter': max_iter}
def get_train_times(exp_dir):
    """Return the training time of each iteration.

    Reads the "# Accounting: time=<N> threads=..." lines from every
    log/train.<iteration>.<job>.log file under ``exp_dir`` and returns a
    dict mapping iteration (int) to the largest time (float) over that
    iteration's jobs. Presumably the jobs of one iteration run in parallel,
    so the slowest job bounds the iteration; TODO confirm.
    """
    train_log_dir = "%s/log/" % (exp_dir)
    train_log_names = "train.*.log"
    # find | xargs rather than a shell glob, so a very large number of log
    # files does not overflow the command-line length limit.
    train_log_lines = common_lib.get_command_stdout(
        'find {0} -name "{1}" | xargs grep -H -e Accounting'.format(
            train_log_dir, train_log_names))
    parse_regex = re.compile(r".*train\.([0-9]+)\.([0-9]+)\.log:# "
                             r"Accounting: time=([0-9]+) thread.*")

    job_times_per_iter = {}
    for line in train_log_lines.split('\n'):
        mat_obj = parse_regex.search(line)
        if mat_obj is None:
            continue
        iteration = int(mat_obj.group(1))
        job = int(mat_obj.group(2))
        job_times_per_iter.setdefault(iteration, {})[job] = float(
            mat_obj.group(3))

    return {iteration: max(job_times.values())
            for iteration, job_times in job_times_per_iter.items()}
def _parse_prob_lines(lines, parse_regex, key):
    """Return {iteration: value} from the lines matching ``parse_regex``
    whose objective name (group 2) equals ``key``."""
    values = {}
    for line in lines.split('\n'):
        mat_obj = parse_regex.search(line)
        if mat_obj is not None and mat_obj.group(2) == key:
            values[int(mat_obj.group(1))] = float(mat_obj.group(3))
    return values


def parse_prob_logs(exp_dir, key='accuracy', output="output"):
    """Parse the train/valid compute_prob logs for the objective ``key``.

    Args:
        exp_dir: directory whose log/compute_prob_{train,valid}.*.log files
            are read.
        key: objective name as printed after "Overall", e.g. 'accuracy' or
            'log-probability'.
        output: name of the network output node.

    Returns:
        A list of (iteration, train value, valid value) tuples, sorted by
        iteration, for the iterations present in both train and valid logs.

    Raises:
        KaldiLogParseException: if no matching train or valid line is found,
            or the two share no iteration.
    """
    train_prob_files = "%s/log/compute_prob_train.*.log" % (exp_dir)
    valid_prob_files = "%s/log/compute_prob_valid.*.log" % (exp_dir)
    train_prob_strings = common_lib.get_command_stdout(
        'grep -e "{0}" {1}'.format(key, train_prob_files))
    valid_prob_strings = common_lib.get_command_stdout(
        'grep -e "{0}" {1}'.format(key, valid_prob_files))

    # LOG
    # (nnet3-chain-compute-prob:PrintTotalStats():nnet-chain-diagnostics.cc:149)
    # Overall log-probability for 'output' is -0.399395 + -0.013437 = -0.412832
    # per frame, over 20000 fra

    # LOG
    # (nnet3-chain-compute-prob:PrintTotalStats():nnet-chain-diagnostics.cc:144)
    # Overall log-probability for 'output' is -0.307255 per frame, over 20000
    # frames.

    # Group 3 is the first number after "is" (for the first form above that
    # is the value before the "+").
    parse_regex = re.compile(
        r".*compute_prob_.*\.([0-9]+)\.log:LOG "
        r".nnet3.*compute-prob.*:PrintTotalStats..:"
        r"nnet.*diagnostics.cc:[0-9]+. Overall ([a-zA-Z\-]+) for "
        r"'{output}'.*is ([0-9.\-e]+) .*per frame".format(
            output=re.escape(output)))

    train_loss = _parse_prob_lines(train_prob_strings, parse_regex, key)
    if not train_loss:
        raise KaldiLogParseException("Could not find any lines with {k} in"
                " {l}".format(k=key, l=train_prob_files))

    valid_loss = _parse_prob_lines(valid_prob_strings, parse_regex, key)
    if not valid_loss:
        raise KaldiLogParseException("Could not find any lines with {k} in"
                " {l}".format(k=key, l=valid_prob_files))

    iters = sorted(set(valid_loss.keys()).intersection(train_loss.keys()))
    if not iters:
        raise KaldiLogParseException("Could not find any common iterations"
                " with key {k} in both {tl} and {vl}".format(
                    k=key, tl=train_prob_files, vl=valid_prob_files))
    # A real list (not a lazy map object) so callers can index and reuse it.
    return [(iteration, train_loss[iteration], valid_loss[iteration])
            for iteration in iters]
def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
    """Build a tab-separated per-iteration report of train/valid objectives.

    Args:
        exp_dir: experiment directory passed to get_train_times() and
            parse_prob_logs().
        key: objective name passed to parse_prob_logs().
        output: network output name passed to parse_prob_logs().

    Returns:
        [report_text, times, data] where ``times`` maps iteration -> training
        time ({} if reading the times failed) and ``data`` is the list of
        (iteration, train value, valid value) tuples ([] if parsing failed).
        Iterations without a recorded training time are left out of the
        table. Failures of either parser are logged as warnings, not raised.
    """
    try:
        times = get_train_times(exp_dir)
    except Exception:
        tb = traceback.format_exc()
        logger.warning("Error getting info from logs, exception was: " + tb)
        # Must be a dict: it is looked up by iteration and summed below.
        times = {}

    report = []
    report.append("%Iter\tduration\ttrain_loss\tvalid_loss\tdifference")
    try:
        data = list(parse_prob_logs(exp_dir, key, output))
    except Exception:
        tb = traceback.format_exc()
        logger.warning("Error getting info from logs, exception was: " + tb)
        data = []
    for iteration, train_value, valid_value in data:
        if iteration not in times:
            continue
        report.append("%d\t%s\t%g\t%g\t%g" % (iteration, str(times[iteration]),
                                              train_value, valid_value,
                                              valid_value - train_value))

    total_time = sum(times.values())
    report.append("Total training time is {0}\n".format(
                    str(datetime.timedelta(seconds=total_time))))
    return ["\n".join(report), times, data]