add mnist data
[jacinto-ai/caffe-jacinto.git] / src / caffe / pyutil / drawnet.py
1 """Functions to draw a caffe NetParameter protobuffer.
2 """
4 from caffe.proto import caffe_pb2
5 from google.protobuf import text_format
6 import pydot
7 import os
8 import sys
10 # Internal layer and blob styles.
11 LAYER_STYLE = {'shape': 'record', 'fillcolor': '#6495ED',
12          'style': 'filled'}
13 NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
14          'style': 'filled'}
15 BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
16         'style': 'filled'}
18 def get_pydot_graph(caffe_net):
19   pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
20   pydot_nodes = {}
21   pydot_edges = []
22   for layer in caffe_net.layers:
23     name = layer.layer.name
24     layertype = layer.layer.type
25     if (len(layer.bottom) == 1 and len(layer.top) == 1 and
26         layer.bottom[0] == layer.top[0]):
27       # We have an in-place neuron layer.
28       pydot_nodes[name + '_' + layertype] = pydot.Node(
29           '%s (%s)' % (name, layertype), **NEURON_LAYER_STYLE)
30     else:
31       pydot_nodes[name + '_' + layertype] = pydot.Node(
32           '%s (%s)' % (name, layertype), **LAYER_STYLE)
33     for bottom_blob in layer.bottom:
34       pydot_nodes[bottom_blob + '_blob'] = pydot.Node(
35         '%s' % (bottom_blob), **BLOB_STYLE)
36       pydot_edges.append((bottom_blob + '_blob', name + '_' + layertype))
37     for top_blob in layer.top:
38       pydot_nodes[top_blob + '_blob'] = pydot.Node(
39         '%s' % (top_blob))
40       pydot_edges.append((name + '_' + layertype, top_blob + '_blob'))
41   # Now, add the nodes and edges to the graph.
42   for node in pydot_nodes.values():
43     pydot_graph.add_node(node)
44   for edge in pydot_edges:
45     pydot_graph.add_edge(
46         pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]]))
47   return pydot_graph
49 def draw_net(caffe_net, ext='png'):
50   """Draws a caffe net and returns the image string encoded using the given
51   extension.
53   Input:
54     caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
55     ext: the image extension. Default 'png'.
56   """
57   return get_pydot_graph(caffe_net).create(format=ext)
59 def draw_net_to_file(caffe_net, filename):
60   """Draws a caffe net, and saves it to file using the format given as the
61   file extension. Use '.raw' to output raw text that you can manually feed
62   to graphviz to draw graphs.
63   """
64   ext = filename[filename.rfind('.')+1:]
65   with open(filename, 'w') as fid:
66     fid.write(draw_net(caffe_net, ext))
68 if __name__ == '__main__':
69   if len(sys.argv) != 3:
70     print 'Usage: %s input_net_proto_file output_image_file' % \
71         os.path.basename(sys.argv[0])
72   else:
73     net = caffe_pb2.NetParameter()
74     text_format.Merge(open(sys.argv[1]).read(), net)
75     print 'Drawing net to %s' % sys.argv[2]
76     draw_net_to_file(net, sys.argv[2])