1 import re
2 import os
3 import matplotlib.pyplot as plt
4 import matplotlib.cm as cm
5 import numpy as np
6 import scipy.spatial as spatial
def get_test_accuracy(log, top_k):
    """Extract test iterations and top-k accuracies from a training log.

    Iterations are read from lines of the form
    ``Iteration <n>, Testing net (#0)``. Accuracies are read from
    ``Test net output #<i>: <name> = <value>`` lines, where ``<name>`` is tried
    in this order and the first naming that yields any match wins:
    ``accuracy/top-<k>``, ``top-<k>``, ``loss/top-<k>``.

    Args:
        log: Full text of the log file.
        top_k: The ``k`` of the top-k accuracy output to look for.

    Returns:
        A tuple ``(iteration, accuracy)`` of a list of ints and a list of
        floats. The two lists are collected independently, so they are only
        aligned when the log holds one accuracy line per test iteration.
    """
    # Plain decimals, integers and scientific notation (e.g. "2.5e-05"), plus
    # the nan/inf tokens that float() understands. The dot is escaped on
    # purpose: an unescaped "." would match any character and silently cut
    # "2.5e-05" down to "2.5".
    number = (r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
              r'|[-+]?(?:nan|inf))')
    k = re.escape(str(top_k))

    # \d+ (not \d*) so an empty capture can never reach int().
    iteration = re.findall(r'Iteration (\d+), Testing net \(#0\)', log)

    accuracy = []
    for prefix in ('accuracy/top-', 'top-', 'loss/top-'):
        # "#\d+" also accepts output indices of 10 and above.
        pattern = r'Test net output #\d+: ' + prefix + k + ' = ' + number
        accuracy = re.findall(pattern, log)
        if accuracy:
            break

    return [int(i) for i in iteration], [float(a) for a in accuracy]
def get_test_loss(log):
    """Extract test iterations and test losses from a training log.

    Iterations are read from ``Iteration <n>, Testing net `` lines. Losses are
    read from ``Test net output #<i>: loss = <value>`` lines; if none are
    found, ``Test net output #<i>: loss/loss = <value>`` is tried instead.

    Args:
        log: Full text of the log file.

    Returns:
        A tuple ``(iteration, loss)`` of a list of ints and a list of floats.
        The two lists are collected independently, so they are only aligned
        when the log holds one loss line per test iteration.
    """
    # Plain decimals, integers and scientific notation (e.g. "1.5e-05"), plus
    # the nan/inf tokens that float() understands. The dot is escaped on
    # purpose: an unescaped "." would match any character and silently cut
    # "1.5e-05" down to "1.5".
    number = (r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
              r'|[-+]?(?:nan|inf))')

    # \d+ (not \d*) so an empty capture can never reach int().
    iteration = re.findall(r'Iteration (\d+), Testing net ', log)

    # "#\d+" also accepts output indices of 10 and above.
    loss = re.findall(r'Test net output #\d+: loss = ' + number, log)
    if not loss:
        loss = re.findall(r'Test net output #\d+: loss/loss = ' + number, log)

    return [int(i) for i in iteration], [float(v) for v in loss]
def get_net_name(log):
    """Return the text following the first ``Solving `` in the log.

    The capture runs to the end of that line. A trailing carriage return is
    dropped so logs with CRLF line endings yield the same name as LF logs, and
    the line does not need to be newline-terminated.

    Args:
        log: Full text of the log file.

    Returns:
        The captured name as a string.

    Raises:
        IndexError: If the log contains no ``Solving `` line. IndexError is
            kept because that is what the original ``findall(...)[0]``
            lookup raised.
    """
    # "." never matches a newline, so the capture stops at the end of the line
    # without requiring a "\n" to be present.
    match = re.search(r"Solving (.*)", log)
    if match is None:
        raise IndexError("no 'Solving <net name>' line found in log")
    return match.group(1).rstrip("\r")
def parse_files(files, top_k=1, separate=False):
    """Parse log files and group their test metrics per network.

    Args:
        files: Iterable of paths to log files.
        top_k: Passed to ``get_test_accuracy`` to select the accuracy output.
        separate: When true, each file is keyed by its base name; otherwise
            files are keyed by the name returned by ``get_net_name``, so
            several logs of the same network are merged into one entry.

    Returns:
        A dict mapping each key to
        ``{"accuracy": {"accuracy": [...], "iteration": [...]},
        "loss": {"loss": [...], "iteration": [...]}}``.
    """
    results = {}
    for path in files:
        with open(path, 'r') as handle:
            text = handle.read()

        if separate:
            key = os.path.basename(path)
        else:
            key = get_net_name(text)

        # Create the empty per-network record on first sight of this key.
        entry = results.setdefault(key, {
            "accuracy": {"accuracy": [], "iteration": []},
            "loss": {"loss": [], "iteration": []},
        })

        acc_iters, acc_values = get_test_accuracy(text, top_k)
        entry["accuracy"]["iteration"].extend(acc_iters)
        entry["accuracy"]["accuracy"].extend(acc_values)

        loss_iters, loss_values = get_test_loss(text)
        entry["loss"]["iteration"].extend(loss_iters)
        entry["loss"]["loss"].extend(loss_values)
    return results
def fmt(x, y):
    """Format a coordinate pair as two lines, each with two decimals."""
    template = 'x: {x:0.2f}\ny: {y:0.2f}'
    return template.format(x=x, y=y)
class FollowDotCursor(object):
    """Display the x,y location of the nearest data point.

    http://stackoverflow.com/a/4674445/190597 (Joe Kington)
    http://stackoverflow.com/a/20637433/190597 (unutbu)

    On every mouse-motion event the instance looks up the data point nearest
    to the pointer, moves a marker dot onto it and updates an annotation with
    the point's formatted coordinates.
    """

    def __init__(self, ax, x, y, formatter=fmt, offsets=(-20, 20)):
        """Build the nearest-point lookup and attach the hover callback.

        Args:
            ax: Axes the data is drawn on. Its x-axis label position is moved
                to the top, and a marker dot and an annotation are added.
            x: x coordinates; if they cannot be converted to float they are
                passed through ``matplotlib.dates.date2num`` first.
            y: y coordinates, convertible to float.
            formatter: Callable ``(x, y) -> str`` producing the annotation
                text.
            offsets: Annotation text offset from the point, in points.
        """
        try:
            x = np.asarray(x, dtype='float')
        except (TypeError, ValueError):
            # Imported here because the module is only needed for date-like
            # x values and is not imported at file level.
            import matplotlib.dates as mdates
            x = np.asarray(mdates.date2num(x), dtype='float')
        y = np.asarray(y, dtype='float')
        # Drop points where either coordinate is NaN.
        mask = ~(np.isnan(x) | np.isnan(y))
        x = x[mask]
        y = y[mask]
        self._points = np.column_stack((x, y))
        self.offsets = offsets
        # Ignore y outliers (beyond 3 standard deviations) when estimating the
        # aspect scale below.
        y = y[np.abs(y - y.mean()) <= 3 * y.std()]
        # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0
        # removed.
        x_range = np.ptp(x)
        y_range = np.ptp(y)
        # x is rescaled so both axes span a comparable range in the KD-tree.
        # Fall back to 1 when either range is zero: a zero scale would
        # collapse every x coordinate to 0 and make all points
        # indistinguishable to the nearest-neighbour query.
        self.scale = y_range / x_range if x_range and y_range else 1
        self.tree = spatial.cKDTree(self.scaled(self._points))
        self.formatter = formatter
        self.ax = ax
        self.fig = ax.figure
        self.ax.xaxis.set_label_position('top')
        self.dot = ax.scatter(
            [x.min()], [y.min()], s=130, color='green', alpha=0.7)
        self.annotation = self.setup_annotation()
        # Connect to the canvas of the figure that owns ``ax`` rather than to
        # whichever figure happens to be current.
        self.fig.canvas.mpl_connect('motion_notify_event', self)

    def scaled(self, points):
        """Return ``points`` with the x column multiplied by ``self.scale``."""
        points = np.asarray(points)
        return points * (self.scale, 1)

    def __call__(self, event):
        """Handle a mouse-motion event by snapping to the nearest point."""
        ax = self.ax
        # event.inaxes is always the current axis. If you use twinx, ax could be
        # a different axis.
        if event.inaxes == ax:
            x, y = event.xdata, event.ydata
        elif event.inaxes is None:
            return
        else:
            # Pointer is over another axes: convert its pixel position into
            # this axes' data coordinates.
            inv = ax.transData.inverted()
            x, y = inv.transform([(event.x, event.y)]).ravel()
        annotation = self.annotation
        x, y = self.snap(x, y)
        annotation.xy = x, y
        annotation.set_text(self.formatter(x, y))
        self.dot.set_offsets((x, y))
        event.canvas.draw()

    def setup_annotation(self):
        """Create the (initially empty) annotation box and return it."""
        annotation = self.ax.annotate(
            '', xy=(0, 0), ha='right',
            xytext=self.offsets, textcoords='offset points', va='bottom',
            bbox=dict(
                boxstyle='round,pad=0.5', fc='yellow', alpha=0.75),
            arrowprops=dict(
                arrowstyle='->', connectionstyle='arc3,rad=0'))
        return annotation

    def snap(self, x, y):
        """Return the value in self.tree closest to x, y."""
        dist, idx = self.tree.query(self.scaled((x, y)), k=1, p=1)
        try:
            return self._points[idx]
        except IndexError:
            # IndexError: index out of bounds
            return self._points[0]
def plot_accuracy(top_k, data, value_at_hover=False):
    """Plot test accuracy (in percent) over iterations, one line per network.

    Args:
        top_k: Only used for the plot title ("Top <k>").
        data: Mapping ``net -> {"accuracy": {"iteration": [...],
            "accuracy": [...]}}``, as produced by ``parse_files``.
        value_at_hover: When true, attach a ``FollowDotCursor`` to each line
            so the nearest data point is annotated on mouse hover.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show or save the
        figure.
    """
    nets = list(data.keys())
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = []
    labels = []
    for net in nets:
        # Consume the colour before any skip so each network keeps the colour
        # of its position regardless of which networks have data.
        color = next(colors)
        # Sort the (iteration, accuracy) pairs by iteration; zip() pairs the
        # lists up to the shorter one.
        points = sorted(zip(data[net]["accuracy"]["iteration"],
                            data[net]["accuracy"]["accuracy"]))
        if not points:
            # No accuracy entries for this network: nothing to draw, and
            # unpacking zip(*[]) would raise ValueError.
            continue
        iteration, accuracy = (np.array(t) for t in zip(*points))
        line, = ax.plot(iteration, accuracy * 100, color=color, linestyle='-')
        handles.append(line)
        labels.append(net)
        if value_at_hover:
            FollowDotCursor(ax, iteration, accuracy * 100)

    # Pass explicit handles: the hover cursors add extra artists to the axes,
    # so pairing labels with "all artists in order" would mislabel the lines.
    ax.legend(handles, labels, loc='lower right')
    plt.title("Top {}".format(top_k))
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy [%]")
    plt.ylim(0, 100)
    plt.grid()
    return plt
def plot_loss(data, value_at_hover=False):
    """Scatter-plot test loss over iterations, one colour per network.

    Args:
        data: Mapping ``net -> {"loss": {"iteration": [...],
            "loss": [...]}}``, as produced by ``parse_files``.
        value_at_hover: When true, attach a ``FollowDotCursor`` to each series
            so the nearest data point is annotated on mouse hover.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can show or save the
        figure.
    """
    nets = list(data.keys())
    colors = iter(cm.rainbow(np.linspace(0, 1, len(nets))))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = []
    labels = []
    for net in nets:
        # Consume the colour before any skip so each network keeps the colour
        # of its position regardless of which networks have data.
        color = next(colors)
        # Sort the (iteration, loss) pairs by iteration; zip() pairs the lists
        # up to the shorter one.
        points = sorted(zip(data[net]["loss"]["iteration"],
                            data[net]["loss"]["loss"]))
        if not points:
            # No loss entries for this network: nothing to draw, and
            # unpacking zip(*[]) would raise ValueError.
            continue
        iteration, loss = (list(t) for t in zip(*points))
        series = ax.scatter(iteration, loss, color=color)
        handles.append(series)
        labels.append(net)
        if value_at_hover:
            FollowDotCursor(ax, iteration, loss)

    # Pass explicit handles: the hover cursors add their own scatter dots to
    # the axes, so pairing labels with "all artists in order" would mislabel
    # the series.
    ax.legend(handles, labels, loc='upper right')
    plt.title("Log Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt