Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - common_plot.py
Merge branch 'orig-caffe-0.16' into nvcaffe-0.16-synced-shuffled
[jacinto-ai/caffe-jacinto.git] / common_plot.py
1 import re
2 import os
3 import matplotlib.pyplot as plt
4 import matplotlib.cm as cm
5 import numpy as np
6 import scipy.spatial as spatial
def get_test_accuracy(log, top_k):
    """Extract the top-k test accuracy values reported in a training log.

    Args:
        log: Full text of the log to search.
        top_k: The ``k`` of the ``top-k`` metric; it is interpolated verbatim
            into the search patterns.

    Returns:
        A tuple ``(iteration, accuracy)``. ``iteration`` is the list of
        iteration numbers (ints) taken from ``Iteration N, Testing net (#0)``
        lines; ``accuracy`` is the list of matched values (floats). The two
        lists come from independent searches, so equal lengths are not
        guaranteed.
    """
    # One numeric token: plain decimal, scientific notation, or nan/inf.
    # The decimal point is escaped so that it cannot match arbitrary text.
    number = r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'

    iteration = re.findall(r'Iteration (\d+), Testing net \(#0\)', log)

    # Three label spellings are tried in priority order; the first one that
    # yields any match wins. `#\d+` accepts output indices of any width.
    accuracy = []
    for prefix in ('accuracy/', '', 'loss/'):
        label = r'Test net output #\d+: {prefix}top-{top_k} = '.format(
            prefix=prefix, top_k=top_k)
        accuracy = re.findall(label + number, log)
        if accuracy:
            break

    return [int(i) for i in iteration], [float(a) for a in accuracy]
def get_test_loss(log):
    """Extract the test loss values reported in a training log.

    Args:
        log: Full text of the log to search.

    Returns:
        A tuple ``(iteration, loss)``. ``iteration`` is the list of iteration
        numbers (ints) taken from ``Iteration N, Testing net `` lines;
        ``loss`` is the list of matched loss values (floats). The two lists
        come from independent searches, so equal lengths are not guaranteed.
    """
    # One numeric token: plain decimal, scientific notation, or nan/inf.
    # Without the exponent part a value such as 1.5e-05 would be read as 1.5.
    number = r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'

    iteration = re.findall(r'Iteration (\d+), Testing net ', log)

    # Prefer the plain "loss" label and fall back to "loss/loss".
    # `#\d+` accepts output indices of any width.
    loss = re.findall(r'Test net output #\d+: loss = ' + number, log)
    if not loss:
        loss = re.findall(r'Test net output #\d+: loss/loss = ' + number, log)

    return [int(i) for i in iteration], [float(v) for v in loss]
def get_net_name(log):
    """Return the text that follows the first ``Solving `` marker in *log*.

    The name is everything between ``Solving `` and the end of that line.
    An ``IndexError`` is raised when the log contains no newline-terminated
    ``Solving `` line.
    """
    candidates = re.findall(r"Solving (.*)\n", log)
    return candidates[0]
def parse_files(files, top_k=1, separate=False):
    """Read log files and collect their test accuracy and loss per net.

    Args:
        files: Iterable of log file paths.
        top_k: Passed through to ``get_test_accuracy``.
        separate: When true, each file is keyed by its base name; otherwise
            files are keyed by the name returned by ``get_net_name``, so
            several logs of the same net are merged into one entry.

    Returns:
        A dict mapping each key to
        ``{"accuracy": {"accuracy": [...], "iteration": [...]},
        "loss": {"loss": [...], "iteration": [...]}}``.
    """
    data = {}
    for path in files:
        with open(path, 'r') as handle:
            log = handle.read()

        if separate:
            net_name = os.path.basename(path)
        else:
            net_name = get_net_name(log)

        # Create the nested containers on first sight of this net.
        entry = data.setdefault(net_name, {
            "accuracy": {"accuracy": [], "iteration": []},
            "loss": {"loss": [], "iteration": []},
        })

        acc_iterations, acc_values = get_test_accuracy(log, top_k)
        entry["accuracy"]["iteration"].extend(acc_iterations)
        entry["accuracy"]["accuracy"].extend(acc_values)

        loss_iterations, loss_values = get_test_loss(log)
        entry["loss"]["iteration"].extend(loss_iterations)
        entry["loss"]["loss"].extend(loss_values)
    return data
def fmt(x, y):
    """Format a coordinate pair as two lines, each value with two decimals."""
    template = 'x: {x:0.2f}\ny: {y:0.2f}'
    return template.format(x=x, y=y)
class FollowDotCursor(object):
    """Display the x,y location of the nearest data point.

    A dot and an annotation box follow the mouse pointer, snapping to the
    data point that lies closest to it.

    http://stackoverflow.com/a/4674445/190597 (Joe Kington)
    http://stackoverflow.com/a/20637433/190597 (unutbu)
    """

    def __init__(self, ax, x, y, formatter=fmt, offsets=(-20, 20)):
        """Attach a snapping cursor for the points ``(x, y)`` to *ax*.

        Args:
            ax: Matplotlib axes on which the dot and annotation are drawn.
            x: Sequence of x values. If it cannot be converted to a float
                array directly it is passed through
                ``matplotlib.dates.date2num`` first.
            y: Sequence of y values convertible to float.
            formatter: Callable ``formatter(x, y)`` returning the annotation
                text for the snapped point.
            offsets: Offset of the annotation text from the point, in points.
        """
        try:
            x = np.asarray(x, dtype='float')
        except (TypeError, ValueError):
            # Imported locally: no top-level import of matplotlib.dates is
            # visible in this module, so the bare name would be undefined.
            import matplotlib.dates as mdates
            x = np.asarray(mdates.date2num(x), dtype='float')
        y = np.asarray(y, dtype='float')

        # Drop every pair in which either coordinate is NaN.
        mask = ~(np.isnan(x) | np.isnan(y))
        x = x[mask]
        y = y[mask]
        self._points = np.column_stack((x, y))
        self.offsets = offsets

        # Measure the y range without values further than three standard
        # deviations from the mean.
        y = y[np.abs(y - y.mean()) <= 3 * y.std()]

        # np.ptp() is used because the ndarray.ptp() method no longer exists
        # in NumPy 2.0.
        x_range = np.ptp(x)
        y_range = np.ptp(y)
        # The scale makes the x extent comparable to the y extent for the
        # nearest-neighbour search. A zero range on either axis falls back
        # to 1; a scale of 0 would collapse all x coordinates onto each other.
        self.scale = y_range / x_range if x_range and y_range else 1

        self.tree = spatial.cKDTree(self.scaled(self._points))
        self.formatter = formatter
        self.ax = ax
        self.fig = ax.figure
        self.ax.xaxis.set_label_position('top')
        self.dot = ax.scatter(
            [x.min()], [y.min()], s=130, color='green', alpha=0.7)
        self.annotation = self.setup_annotation()
        # Listen on the canvas of the figure that owns `ax`, which is not
        # necessarily pyplot's current figure.
        self._cid = self.fig.canvas.mpl_connect('motion_notify_event', self)

    def scaled(self, points):
        """Return *points* with the x column multiplied by ``self.scale``."""
        points = np.asarray(points)
        return points * (self.scale, 1)

    def __call__(self, event):
        """Handle a mouse-motion event by moving the dot and annotation."""
        ax = self.ax
        # event.inaxes is always the current axis. If you use twinx, ax could
        # be a different axis.
        if event.inaxes == ax:
            x, y = event.xdata, event.ydata
        elif event.inaxes is None:
            # Pointer is outside every axes: nothing to update.
            return
        else:
            # Convert the pixel position into this axes' data coordinates.
            inv = ax.transData.inverted()
            x, y = inv.transform([(event.x, event.y)]).ravel()
        annotation = self.annotation
        x, y = self.snap(x, y)
        annotation.xy = x, y
        annotation.set_text(self.formatter(x, y))
        self.dot.set_offsets((x, y))
        event.canvas.draw()

    def setup_annotation(self):
        """Create the (initially empty) annotation box and return it."""
        annotation = self.ax.annotate(
            '', xy=(0, 0), ha='right',
            xytext=self.offsets, textcoords='offset points', va='bottom',
            bbox=dict(
                boxstyle='round,pad=0.5', fc='yellow', alpha=0.75),
            arrowprops=dict(
                arrowstyle='->', connectionstyle='arc3,rad=0'))
        return annotation

    def snap(self, x, y):
        """Return the value in self.tree closest to x, y.

        The search runs in the scaled coordinate space with ``p=1``; the
        returned point is in original (unscaled) coordinates.
        """
        dist, idx = self.tree.query(self.scaled((x, y)), k=1, p=1)
        try:
            return self._points[idx]
        except IndexError:
            # The returned index was out of bounds for the stored points;
            # fall back to the first point.
            return self._points[0]
def plot_accuracy(top_k, data, value_at_hover=False):
    """Plot test accuracy (in percent) against iteration, one line per net.

    Args:
        top_k: Value shown in the figure title (``"Top {top_k}"``).
        data: Mapping from net name to a dict whose ``["accuracy"]`` entry
            holds the lists ``"iteration"`` and ``"accuracy"``.
        value_at_hover: When true, a ``FollowDotCursor`` is attached for each
            plotted net.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure.
    """
    nets = list(data)
    colors = cm.rainbow(np.linspace(0, 1, len(nets)))
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Legend entries are collected explicitly so that each label is tied to
    # the line it describes, even when a net is skipped below.
    handles = []
    labels = []
    for net, color in zip(nets, colors):
        series = data[net]["accuracy"]
        # Pair the values up and order them by iteration.
        points = sorted(zip(series["iteration"], series["accuracy"]))
        if not points:
            # Nothing was collected for this net; unpacking an empty
            # zip(*points) would raise ValueError.
            continue
        iteration, accuracy = (np.array(t) for t in zip(*points))
        line, = ax.plot(iteration, accuracy * 100, color=color, linestyle='-')
        handles.append(line)
        labels.append(net)
        if value_at_hover:
            # The cursor registers itself as an event callback on creation.
            FollowDotCursor(ax, iteration, accuracy * 100)

    plt.legend(handles, labels, loc='lower right')
    plt.title("Top {}".format(top_k))
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy [%]")
    plt.ylim(0, 100)
    plt.grid()
    return plt
def plot_loss(data, value_at_hover=False):
    """Scatter-plot test loss against iteration, one colour per net.

    Args:
        data: Mapping from net name to a dict whose ``["loss"]`` entry holds
            the lists ``"iteration"`` and ``"loss"``.
        value_at_hover: When true, a ``FollowDotCursor`` is attached for each
            plotted net.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure.
    """
    nets = list(data)
    colors = cm.rainbow(np.linspace(0, 1, len(nets)))
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Legend entries are collected explicitly: the hover dots created by
    # FollowDotCursor are scatter artists too, so matching labels to artists
    # by position would mislabel them once more than one net is drawn.
    handles = []
    labels = []
    for net, color in zip(nets, colors):
        series = data[net]["loss"]
        # Pair the values up and order them by iteration.
        points = sorted(zip(series["iteration"], series["loss"]))
        if not points:
            # Nothing was collected for this net; unpacking an empty
            # zip(*points) would raise ValueError.
            continue
        iteration, loss = (list(t) for t in zip(*points))
        handles.append(ax.scatter(iteration, loss, color=color))
        labels.append(net)
        if value_at_hover:
            # The cursor registers itself as an event callback on creation.
            FollowDotCursor(ax, iteration, loss)

    plt.legend(handles, labels, loc='upper right')
    plt.title("Log Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Log Loss")
    plt.xlim(0)
    plt.grid()
    return plt