1 import re
2 import os
3 import matplotlib.pyplot as plt
4 import matplotlib.cm as cm
5 import numpy as np
6 import scipy.spatial as spatial
def get_test_accuracy(log, top_k):
    """Extract test accuracy values and test iteration numbers from a log.

    The accuracy output is looked up under several labels, in this order,
    and the first label that yields at least one match wins:
    ``accuracy/top-<k>``, ``top-<k>``, ``loss/top-<k>``, ``accuracy``.

    Args:
        log: Full text of the training log.
        top_k: The ``k`` used to build the ``top-<k>`` output labels.

    Returns:
        A tuple ``(iteration, accuracy)``: a list of ints taken from the
        ``Iteration <n>, Testing net (#0)`` lines and a list of floats taken
        from the matching ``Test net output`` lines. The two lists come from
        independent searches, so they are not guaranteed to be equally long.
    """
    # Plain decimals, scientific notation (e.g. 9.5e-07) and nan/inf.
    # A bare `\d*.\d*` would cut "9.5e-07" down to "9.5" and choke on "nan".
    number = (r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
              r'|[-+]?(?:nan|inf))')
    # `\d+` so that output indices of 10 and above are matched as well.
    prefix = r'Test net output #\d+: '
    labels = (
        'accuracy/top-{top_k}'.format(top_k=top_k),
        'top-{top_k}'.format(top_k=top_k),
        'loss/top-{top_k}'.format(top_k=top_k),
        'accuracy',
    )

    accuracy = []
    for label in labels:
        accuracy = re.findall(prefix + re.escape(label) + ' = ' + number, log)
        if accuracy:
            break

    iteration = re.findall(r'Iteration (\d+), Testing net \(#0\)', log)
    return [int(i) for i in iteration], [float(a) for a in accuracy]
def get_test_loss(log):
    """Extract test loss values and test iteration numbers from a log.

    The loss is read from ``Test net output #<i>: loss = <value>`` lines; if
    none exist, ``Test net output #<i>: loss/loss = <value>`` is tried.

    Args:
        log: Full text of the training log.

    Returns:
        A tuple ``(iteration, loss)``: a list of ints taken from every
        ``Iteration <n>, Testing net `` line and a list of floats with the
        parsed loss values. The two lists come from independent searches, so
        they are not guaranteed to be equally long.
    """
    # Plain decimals, scientific notation (e.g. 9.5e-07) and nan/inf.
    # A bare `\d*.\d*` would cut "9.5e-07" down to "9.5" and choke on "nan".
    number = (r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
              r'|[-+]?(?:nan|inf))')
    # `\d+` so that output indices of 10 and above are matched as well.
    prefix = r'Test net output #\d+: '

    loss = re.findall(prefix + 'loss = ' + number, log)
    if not loss:
        loss = re.findall(prefix + 'loss/loss = ' + number, log)

    iteration = re.findall(r'Iteration (\d+), Testing net ', log)
    return [int(i) for i in iteration], [float(v) for v in loss]
def get_train_loss(log):
    """Extract training loss values and their iteration numbers from a log.

    Args:
        log: Full text of the training log.

    Returns:
        A tuple ``(iteration, loss)``: a list of ints taken from the
        ``Iteration <n>, lr = `` lines and a list of floats taken from the
        ``Train net output #<i>: loss = <value>`` lines. The two lists come
        from independent searches, so they are not guaranteed to be equally
        long.
    """
    # Plain decimals, scientific notation (e.g. 9.5e-07) and nan/inf.
    # A bare `\d*.\d*` would cut "9.5e-07" down to "9.5" and choke on "nan".
    number = (r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
              r'|[-+]?(?:nan|inf))')

    iteration = re.findall(r'Iteration (\d+), lr = ', log)
    # `\d+` so that output indices of 10 and above are matched as well.
    loss = re.findall(r'Train net output #\d+: loss = ' + number, log)
    return [int(i) for i in iteration], [float(v) for v in loss]
def get_net_name(log):
    """Return the text following the first ``Solving `` marker in the log.

    The name is everything between ``Solving `` and the end of that line.
    An ``IndexError`` is raised when the log has no such line.
    """
    solving_pattern = re.compile(r"Solving (.*)\n")
    candidates = solving_pattern.findall(log)
    first_name = candidates[0]
    return first_name
def parse_files(files, top_k=1, separate=False):
    """Parse several log files and group their curves by network.

    Args:
        files: Iterable of paths to log files.
        top_k: Forwarded to ``get_test_accuracy`` to select the accuracy output.
        separate: When true, every file is keyed by its base name; otherwise
            files are keyed by the net name found in the log, so logs that
            report the same net name are merged into one entry.

    Returns:
        A dict mapping each key to three series -- ``"accuracy"``, ``"loss"``
        and ``"train_loss"`` -- each holding an ``"iteration"`` list and a
        value list (``"accuracy"`` or ``"loss"``).
    """
    data = {}
    for path in files:
        with open(path, 'r') as handle:
            log = handle.read()

        net_name = get_net_name(log) if not separate else os.path.basename(path)
        if net_name not in data:
            data[net_name] = {
                "accuracy": {"accuracy": [], "iteration": []},
                "loss": {"loss": [], "iteration": []},
                "train_loss": {"loss": [], "iteration": []},
            }
        entry = data[net_name]

        steps, values = get_test_accuracy(log, top_k)
        entry["accuracy"]["iteration"].extend(steps)
        entry["accuracy"]["accuracy"].extend(values)

        steps, values = get_test_loss(log)
        entry["loss"]["iteration"].extend(steps)
        entry["loss"]["loss"].extend(values)

        steps, values = get_train_loss(log)
        entry["train_loss"]["iteration"].extend(steps)
        entry["train_loss"]["loss"].extend(values)

    return data
def fmt(x, y):
    """Render a coordinate pair as two lines of text with two decimals each."""
    template = 'x: {x:0.2f}\ny: {y:0.2f}'
    coords = {'x': x, 'y': y}
    return template.format(**coords)
class FollowDotCursor(object):
    """Display the x,y location of the nearest data point.

    http://stackoverflow.com/a/4674445/190597 (Joe Kington)
    http://stackoverflow.com/a/20637433/190597 (unutbu)

    On every mouse-motion event the instance looks up the data point closest
    to the pointer, moves a marker dot onto it and updates an annotation with
    the point's formatted coordinates.
    """

    def __init__(self, ax, x, y, formatter=fmt, offsets=(-20, 20)):
        """Attach a snapping cursor for the points ``(x, y)`` to ``ax``.

        Args:
            ax: Axes the cursor dot and annotation are drawn on.
            x: Sequence of x values; values that cannot be converted to float
                directly are passed through ``matplotlib.dates.date2num``.
            y: Sequence of y values, convertible to float.
            formatter: Callable ``(x, y) -> str`` producing the annotation text.
            offsets: Annotation text offset from the point, in points.

        Raises:
            ValueError: If no point with finite x and y remains.
        """
        try:
            x = np.asarray(x, dtype='float')
        except (TypeError, ValueError):
            # Imported here: only needed for date-like x values, and the
            # module is not otherwise in scope.
            import matplotlib.dates as mdates
            x = np.asarray(mdates.date2num(x), dtype='float')
        y = np.asarray(y, dtype='float')
        mask = ~(np.isnan(x) | np.isnan(y))
        x = x[mask]
        y = y[mask]
        if x.size == 0:
            raise ValueError(
                'FollowDotCursor needs at least one non-NaN (x, y) point')
        self._points = np.column_stack((x, y))
        self.offsets = offsets
        # Estimate the y range without outliers (beyond 3 standard deviations)
        # so that a few extreme values do not dominate the x/y scaling.
        y = y[np.abs(y - y.mean()) <= 3 * y.std()]
        # np.ptp instead of ndarray.ptp(), which NumPy 2.0 removed.
        x_range = np.ptp(x)
        y_range = np.ptp(y)
        # Scale x so both axes span a comparable range in the KD-tree. Fall
        # back to 1 when either range is zero: a zero scale would collapse all
        # points in x and make snapping ignore the pointer's x position.
        self.scale = y_range / x_range if x_range and y_range else 1
        self.tree = spatial.cKDTree(self.scaled(self._points))
        self.formatter = formatter
        self.ax = ax
        self.fig = ax.figure
        self.ax.xaxis.set_label_position('top')
        self.dot = ax.scatter(
            [x.min()], [y.min()], s=130, color='green', alpha=0.7)
        self.annotation = self.setup_annotation()
        plt.connect('motion_notify_event', self)

    def scaled(self, points):
        """Return ``points`` with the x column multiplied by ``self.scale``."""
        points = np.asarray(points)
        return points * (self.scale, 1)

    def __call__(self, event):
        """Handle a mouse-motion event: snap the dot and annotation."""
        ax = self.ax
        # event.inaxes is always the current axis. If you use twinx, ax could be
        # a different axis.
        if event.inaxes == ax:
            x, y = event.xdata, event.ydata
        elif event.inaxes is None:
            return
        else:
            # Convert the pixel position into this axes' data coordinates.
            inv = ax.transData.inverted()
            x, y = inv.transform([(event.x, event.y)]).ravel()
        annotation = self.annotation
        x, y = self.snap(x, y)
        annotation.xy = x, y
        annotation.set_text(self.formatter(x, y))
        self.dot.set_offsets((x, y))
        event.canvas.draw()

    def setup_annotation(self):
        """Create the (initially empty) annotation box and return it."""
        annotation = self.ax.annotate(
            '', xy=(0, 0), ha='right',
            xytext=self.offsets, textcoords='offset points', va='bottom',
            bbox=dict(
                boxstyle='round,pad=0.5', fc='yellow', alpha=0.75),
            arrowprops=dict(
                arrowstyle='->', connectionstyle='arc3,rad=0'))
        return annotation

    def snap(self, x, y):
        """Return the value in self.tree closest to x, y."""
        _, idx = self.tree.query(self.scaled((x, y)), k=1, p=1)
        try:
            return self._points[idx]
        except IndexError:
            # IndexError: index out of bounds
            return self._points[0]
def plot_accuracy(top_k, data, value_at_hover=False):
    """Plot test accuracy (in percent) against iteration, one line per net.

    Args:
        top_k: Number shown in the figure title (``Top <k>``).
        data: Mapping ``net -> {"accuracy": {"iteration": [...],
            "accuracy": [...]}}``; accuracies are multiplied by 100 for
            display.
        value_at_hover: When true, attach a ``FollowDotCursor`` to every
            plotted line.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """
    nets = list(data)
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = []
    labels = []
    for net in nets:
        # Take the color before any skip so each net keeps its own color.
        color = next(colors)
        series = data[net]["accuracy"]
        points = sorted(zip(series["iteration"], series["accuracy"]))
        if not points:
            # Nothing to draw for this net; unpacking an empty zip would raise.
            continue
        iteration, accuracy = (np.array(t) for t in zip(*points))
        line, = ax.plot(iteration, accuracy * 100, color=color, linestyle='-')
        handles.append(line)
        labels.append(net)
        if value_at_hover:
            # The cursor connects itself to the figure's motion events.
            FollowDotCursor(ax, iteration, accuracy * 100)

    # Explicit handles: the hover cursor adds its own artists to the axes, so
    # pairing labels with artists by position would mislabel the curves.
    plt.legend(handles, labels, loc='lower right')
    plt.title("Top {}".format(top_k))
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy [%]")
    plt.ylim(0, 100)
    plt.grid()
    return plt
def plot_loss(data, value_at_hover=False):
    """Scatter-plot test loss against iteration, one color per net.

    Args:
        data: Mapping ``net -> {"loss": {"iteration": [...], "loss": [...]}}``.
        value_at_hover: When true, attach a ``FollowDotCursor`` to every
            plotted series.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """
    nets = list(data)
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = []
    labels = []
    for net in nets:
        # Take the color before any skip so each net keeps its own color.
        color = next(colors)
        series = data[net]["loss"]
        points = sorted(zip(series["iteration"], series["loss"]))
        if not points:
            # Nothing to draw for this net; unpacking an empty zip would raise.
            continue
        iteration, loss = (list(t) for t in zip(*points))
        handles.append(ax.scatter(iteration, loss, color=color))
        labels.append(net)
        if value_at_hover:
            # The cursor connects itself to the figure's motion events.
            FollowDotCursor(ax, iteration, loss)

    # Explicit handles: the hover cursor adds its own scatter dot to the axes,
    # so pairing labels with artists by position would mislabel the series.
    plt.legend(handles, labels, loc='upper right')
    plt.title("Log Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt
def plot_train_loss(data, value_at_hover=False):
    """Scatter-plot training loss against iteration, one color per net.

    Args:
        data: Mapping ``net -> {"train_loss": {"iteration": [...],
            "loss": [...]}}``.
        value_at_hover: When true, attach a ``FollowDotCursor`` to every
            plotted series.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """
    nets = list(data)
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = []
    labels = []
    for net in nets:
        # Take the color before any skip so each net keeps its own color.
        color = next(colors)
        series = data[net]["train_loss"]
        points = sorted(zip(series["iteration"], series["loss"]))
        if not points:
            # Nothing to draw for this net; unpacking an empty zip would raise.
            continue
        iteration, loss = (list(t) for t in zip(*points))
        handles.append(ax.scatter(iteration, loss, color=color))
        labels.append(net)
        if value_at_hover:
            # The cursor connects itself to the figure's motion events.
            FollowDotCursor(ax, iteration, loss)

    # Explicit handles: the hover cursor adds its own scatter dot to the axes,
    # so pairing labels with artists by position would mislabel the series.
    plt.legend(handles, labels, loc='upper right')
    plt.title("Log Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt