]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - examples/demo_mnist.cpp
Merge branch 'master' of github.com:Yangqing/caffeine
[jacinto-ai/caffe-jacinto.git] / examples / demo_mnist.cpp
1 // Copyright 2013 Yangqing Jia
2 // This example shows how to run a modified version of LeNet using Caffe.
4 #include <cuda_runtime.h>
5 #include <fcntl.h>
6 #include <google/protobuf/text_format.h>
8 #include <cstring>
9 #include <iostream>
11 #include "caffe/blob.hpp"
12 #include "caffe/common.hpp"
13 #include "caffe/net.hpp"
14 #include "caffe/filler.hpp"
15 #include "caffe/proto/caffe.pb.h"
16 #include "caffe/util/io.hpp"
17 #include "caffe/solver.hpp"
19 using namespace caffe;
21 int main(int argc, char** argv) {
22   if (argc < 3) {
23     std::cout << "Usage:" << std::endl;
24     std::cout << "demo_mnist.bin train_file test_file [CPU/GPU]" << std::endl;
25     return 0;
26   }
27   google::InitGoogleLogging(argv[0]);
28   Caffe::DeviceQuery();
30   if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
31     LOG(ERROR) << "Using GPU";
32     Caffe::set_mode(Caffe::GPU);
33   } else {
34     LOG(ERROR) << "Using CPU";
35     Caffe::set_mode(Caffe::CPU);
36   }
38   // Start training
39   Caffe::set_phase(Caffe::TRAIN);
41   NetParameter net_param;
42   ReadProtoFromTextFile(argv[1],
43       &net_param);
44   vector<Blob<float>*> bottom_vec;
45   Net<float> caffe_net(net_param, bottom_vec);
47   // Run the network without training.
48   LOG(ERROR) << "Performing Forward";
49   caffe_net.Forward(bottom_vec);
50   LOG(ERROR) << "Performing Backward";
51   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
53   SolverParameter solver_param;
54   // Solver Parameters are hard-coded in this case, but you can write a
55   // SolverParameter protocol buffer to specify all these values.
56   solver_param.set_base_lr(0.01);
57   solver_param.set_display(100);
58   solver_param.set_max_iter(5000);
59   solver_param.set_lr_policy("inv");
60   solver_param.set_gamma(0.0001);
61   solver_param.set_power(0.75);
62   solver_param.set_momentum(0.9);
63   solver_param.set_weight_decay(0.0005);
65   LOG(ERROR) << "Starting Optimization";
66   SGDSolver<float> solver(solver_param);
67   solver.Solve(&caffe_net);
68   LOG(ERROR) << "Optimization Done.";
70   // Write the trained network to a NetParameter protobuf. If you are training
71   // the model and saving it for later, this is what you want to serialize and
72   // store.
73   NetParameter trained_net_param;
74   caffe_net.ToProto(&trained_net_param);
76   // Now, let's starting doing testing.
77   Caffe::set_phase(Caffe::TEST);
79   // Using the testing data to test the accuracy.
80   NetParameter test_net_param;
81   ReadProtoFromTextFile(argv[2], &test_net_param);
82   Net<float> caffe_test_net(test_net_param, bottom_vec);
83   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
85   double test_accuracy = 0;
86   int batch_size = test_net_param.layers(0).layer().batchsize();
87   for (int i = 0; i < 10000 / batch_size; ++i) {
88     const vector<Blob<float>*>& result =
89         caffe_test_net.Forward(bottom_vec);
90     test_accuracy += result[0]->cpu_data()[0];
91   }
92   test_accuracy /= 10000 / batch_size;
93   LOG(ERROR) << "Test accuracy:" << test_accuracy;
95   return 0;
96 }