Added support for plotting epochs instead of iterations
[jacinto-ai/caffe-jacinto.git] / common_plot.py

import re
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.dates as mdates  # used by FollowDotCursor when x holds dates
import numpy as np
import scipy.spatial as spatial

def get_test_accuracy(log, top_k):
    # Parse test iterations and top-k accuracies from a Caffe training log.
    # Several naming variants of the accuracy output are tried in turn.
    iteration = re.findall(r'Iteration (\d*), Testing net \(#0\)', log)
    accuracy = re.findall(r'Test net output #\d: accuracy/top-{top_k} = (\d*\.\d*)'.format(top_k=top_k), log)
    if len(accuracy) == 0:
        accuracy = re.findall(r'Test net output #\d: top-{top_k} = (\d*\.\d*)'.format(top_k=top_k), log)
    if len(accuracy) == 0:
        accuracy = re.findall(r'Test net output #\d: loss/top-{top_k} = (\d*\.\d*)'.format(top_k=top_k), log)
    if len(accuracy) == 0:
        accuracy = re.findall(r'Test net output #\d: accuracy/top{top_k} = (\d*\.\d*)'.format(top_k=top_k), log)
    if len(accuracy) == 0:
        accuracy = re.findall(r'Test net output #\d: accuracy = (\d*\.\d*)', log)
    iteration = [int(i) for i in iteration]
    accuracy = [float(i) for i in accuracy]
    return iteration, accuracy
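
# Illustrative examples of the log lines the patterns above are written
# against (synthetic lines, reconstructed from the regular expressions rather
# than copied from a real log):
#
#   Iteration 4000, Testing net (#0)
#       Test net output #0: accuracy/top-1 = 0.6842
#
# The two result lists are paired positionally, so each parsed accuracy is
# assumed to belong to the test iteration reported just before it.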

def get_test_loss(log):
    # Parse test iterations and test losses from a Caffe training log.
    iteration = re.findall(r'Iteration (\d*), Testing net ', log)
    loss = re.findall(r'Test net output #\d: loss = (\d*\.\d*)', log)
    if len(loss) == 0:
        loss = re.findall(r'Test net output #\d: loss/loss = (\d*\.\d*)', log)
    iteration = [int(i) for i in iteration]
    loss = [float(i) for i in loss]
    return iteration, loss

def get_train_loss(log):
    # Parse training iterations and training losses from a Caffe training log.
    iteration = re.findall(r'Iteration (\d*), lr = ', log)
    loss = re.findall(r'Train net output #\d: loss = (\d*\.\d*)', log)
    iteration = [int(i) for i in iteration]
    loss = [float(i) for i in loss]
    return iteration, loss

def get_epochs(log):
    # Convert the solver's max_iter into epochs. The GPU count is hard-coded,
    # and 1281167 is the number of images in the ImageNet-1k training set.
    num_gpus = 8
    max_iter = re.findall(r'max_iter: (\d*)', log)
    iter_size = re.findall(r'iter_size: (\d*)', log)
    batch_size = re.findall(r'batch_size: (\d*)', log)
    max_iter = int(max_iter[0])
    if len(iter_size) > 0:
        iter_size = int(iter_size[0])
    else:
        iter_size = 1
    batch_size = int(batch_size[0])
    # print(max_iter, iter_size, batch_size)
    num_epochs = int(round((max_iter * iter_size * batch_size * num_gpus) / 1281167. + 0.5))
    return max_iter, num_epochs
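
# Worked example of the epoch conversion above (the solver numbers are
# illustrative assumptions, not taken from a real prototxt): with
# max_iter = 100000, iter_size = 1, batch_size = 32 and the hard-coded
# num_gpus = 8, the total number of images seen is
# 100000 * 1 * 32 * 8 = 25,600,000; dividing by the 1,281,167 ImageNet-1k
# training images gives roughly 19.98, which the rounding above reports
# as 20 epochs.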

def get_net_name(log):
    return re.findall(r"Solving (.*)\n", log)[0]

def parse_files(files, top_k=1, separate=False):
    # Build a per-net dictionary of test accuracy, test loss and train loss
    # curves, with iteration counts rescaled to epochs.
    data = {}
    for file in files:
        with open(file, 'r') as fp:
            log = fp.read()
            net_name = os.path.basename(file) if separate else get_net_name(log)
            if net_name not in data:
                data[net_name] = {}
                data[net_name]["accuracy"] = {}
                data[net_name]["accuracy"]["accuracy"] = []
                data[net_name]["accuracy"]["iteration"] = []
                data[net_name]["loss"] = {}
                data[net_name]["loss"]["loss"] = []
                data[net_name]["loss"]["iteration"] = []
                data[net_name]["train_loss"] = {}
                data[net_name]["train_loss"]["loss"] = []
                data[net_name]["train_loss"]["iteration"] = []

            max_iter, epochs = get_epochs(log)
            # print(epochs)
            scale = float(epochs) / max_iter

            iteration, accuracy = get_test_accuracy(log, top_k)
            iteration = [k * scale for k in iteration]
            data[net_name]["accuracy"]["iteration"].extend(iteration)
            data[net_name]["accuracy"]["accuracy"].extend(accuracy)

            iteration, loss = get_test_loss(log)
            iteration = [k * scale for k in iteration]
            data[net_name]["loss"]["iteration"].extend(iteration)
            data[net_name]["loss"]["loss"].extend(loss)

            iteration, loss = get_train_loss(log)
            iteration = [k * scale for k in iteration]
            data[net_name]["train_loss"]["iteration"].extend(iteration)
            data[net_name]["train_loss"]["loss"].extend(loss)

    return data
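
# Minimal usage sketch for parse_files (the log file name below is a
# hypothetical placeholder):
#
#   data = parse_files(['jobs/resnet10_train.log'], top_k=1, separate=True)
#   plot_accuracy(1, data).show()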

def fmt(x, y):
    return 'x: {x:0.2f}\ny: {y:0.2f}'.format(x=x, y=y)

class FollowDotCursor(object):
    """Display the x,y location of the nearest data point.
    http://stackoverflow.com/a/4674445/190597 (Joe Kington)
    http://stackoverflow.com/a/20637433/190597 (unutbu)
    """
    def __init__(self, ax, x, y, formatter=fmt, offsets=(-20, 20)):
        try:
            x = np.asarray(x, dtype='float')
        except (TypeError, ValueError):
            x = np.asarray(mdates.date2num(x), dtype='float')
        y = np.asarray(y, dtype='float')
        mask = ~(np.isnan(x) | np.isnan(y))
        x = x[mask]
        y = y[mask]
        self._points = np.column_stack((x, y))
        self.offsets = offsets
        y = y[np.abs(y - y.mean()) <= 3 * y.std()]
        self.scale = x.ptp()
        self.scale = y.ptp() / self.scale if self.scale else 1
        self.tree = spatial.cKDTree(self.scaled(self._points))
        self.formatter = formatter
        self.ax = ax
        self.fig = ax.figure
        self.ax.xaxis.set_label_position('top')
        self.dot = ax.scatter(
            [x.min()], [y.min()], s=130, color='green', alpha=0.7)
        self.annotation = self.setup_annotation()
        plt.connect('motion_notify_event', self)

    def scaled(self, points):
        points = np.asarray(points)
        return points * (self.scale, 1)

    def __call__(self, event):
        ax = self.ax
        # event.inaxes is always the current axis. If you use twinx, ax could be
        # a different axis.
        if event.inaxes == ax:
            x, y = event.xdata, event.ydata
        elif event.inaxes is None:
            return
        else:
            inv = ax.transData.inverted()
            x, y = inv.transform([(event.x, event.y)]).ravel()
        annotation = self.annotation
        x, y = self.snap(x, y)
        annotation.xy = x, y
        annotation.set_text(self.formatter(x, y))
        self.dot.set_offsets((x, y))
        event.canvas.draw()

    def setup_annotation(self):
        """Draw and hide the annotation box."""
        annotation = self.ax.annotate(
            '', xy=(0, 0), ha='right',
            xytext=self.offsets, textcoords='offset points', va='bottom',
            bbox=dict(
                boxstyle='round,pad=0.5', fc='yellow', alpha=0.75),
            arrowprops=dict(
                arrowstyle='->', connectionstyle='arc3,rad=0'))
        return annotation

    def snap(self, x, y):
        """Return the value in self.tree closest to x, y."""
        dist, idx = self.tree.query(self.scaled((x, y)), k=1, p=1)
        try:
            return self._points[idx]
        except IndexError:
            # IndexError: index out of bounds
            return self._points[0]
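
# FollowDotCursor can also be attached to any Matplotlib axes on its own; a
# minimal sketch with made-up data:
#
#   fig, ax = plt.subplots()
#   xs = np.arange(10, dtype=float)
#   ys = xs ** 2
#   ax.plot(xs, ys, 'o-')
#   cursor = FollowDotCursor(ax, xs, ys)  # keep a reference while the figure is open
#   plt.show()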

def plot_accuracy(top_k, data, value_at_hover=False):
    nets = data.keys()
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for net in nets:
        iteration = data[net]["accuracy"]["iteration"]
        accuracy = data[net]["accuracy"]["accuracy"]
        iteration, accuracy = (np.array(t) for t in zip(*sorted(zip(iteration, accuracy))))
        ax.plot(iteration, accuracy * 100, color=next(colors), linestyle='-')
        if value_at_hover:
            cursor = FollowDotCursor(ax, iteration, accuracy * 100)

    plt.legend(nets, loc='lower right')
    plt.title("Top {}".format(top_k))
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy [%]")
    plt.ylim(0, 100)
    plt.grid()
    return plt

def plot_loss(data, value_at_hover=False):
    nets = data.keys()
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for net in nets:
        iteration = data[net]["loss"]["iteration"]
        loss = data[net]["loss"]["loss"]
        iteration, loss = (list(t) for t in zip(*sorted(zip(iteration, loss))))
        ax.scatter(iteration, loss, color=next(colors))
        if value_at_hover:
            cursor = FollowDotCursor(ax, iteration, loss)

    plt.legend(nets, loc='upper right')
    plt.title("Log Loss")
    # The x values were rescaled to epochs in parse_files, so label them as such.
    plt.xlabel("Epochs")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt

def plot_train_loss(data, value_at_hover=False):
    nets = data.keys()
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for net in nets:
        iteration = data[net]["train_loss"]["iteration"]
        loss = data[net]["train_loss"]["loss"]
        iteration, loss = (list(t) for t in zip(*sorted(zip(iteration, loss))))
        ax.scatter(iteration, loss, color=next(colors))
        if value_at_hover:
            cursor = FollowDotCursor(ax, iteration, loss)

    plt.legend(nets, loc='upper right')
    plt.title("Train Log Loss")
    # The x values were rescaled to epochs in parse_files, so label them as such.
    plt.xlabel("Epochs")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt
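
# A minimal command-line driver is sketched below; it is not part of the
# original module, and the argument handling is an assumption: pass one or
# more Caffe training log files on the command line.
if __name__ == '__main__':
    import sys

    log_files = sys.argv[1:]
    if log_files:
        # Parse each log separately so the legend shows one entry per file.
        parsed = parse_files(log_files, top_k=1, separate=True)
        plot_accuracy(1, parsed, value_at_hover=True)
        plot_loss(parsed)
        plot_train_loss(parsed)
        plt.show()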