Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - examples/demo_mnist.cpp
Reorganization of code.
[jacinto-ai/caffe-jacinto.git] / examples / demo_mnist.cpp
1 // Copyright 2013 Yangqing Jia
#include <cuda_runtime.h>
#include <fcntl.h>
#include <google/protobuf/text_format.h>

#include <cstdlib>
#include <cstring>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"
17 using namespace caffe;
19 int main(int argc, char** argv) {
20   cudaSetDevice(1);
21   Caffe::set_mode(Caffe::GPU);
22   Caffe::set_phase(Caffe::TRAIN);
24   NetParameter net_param;
25   ReadProtoFromTextFile("data/lenet.prototxt",
26       &net_param);
27   vector<Blob<float>*> bottom_vec;
28   Net<float> caffe_net(net_param, bottom_vec);
30   // Run the network without training.
31   LOG(ERROR) << "Performing Forward";
32   caffe_net.Forward(bottom_vec);
33   LOG(ERROR) << "Performing Backward";
34   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
36   SolverParameter solver_param;
37   solver_param.set_base_lr(0.01);
38   solver_param.set_display(100);
39   solver_param.set_max_iter(6000);
40   solver_param.set_lr_policy("inv");
41   solver_param.set_gamma(0.0001);
42   solver_param.set_power(0.75);
43   solver_param.set_momentum(0.9);
44   solver_param.set_weight_decay(0.0005);
46   LOG(ERROR) << "Starting Optimization";
47   SGDSolver<float> solver(solver_param);
48   solver.Solve(&caffe_net);
49   LOG(ERROR) << "Optimization Done.";
51   // Run the network after training.
52   LOG(ERROR) << "Performing Forward";
53   caffe_net.Forward(bottom_vec);
54   LOG(ERROR) << "Performing Backward";
55   float loss = caffe_net.Backward();
56   LOG(ERROR) << "Final loss: " << loss;
58   NetParameter trained_net_param;
59   caffe_net.ToProto(&trained_net_param);
61   NetParameter traintest_net_param;
62   ReadProtoFromTextFile("data/lenet_traintest.prototxt",
63       &traintest_net_param);
64   Net<float> caffe_traintest_net(traintest_net_param, bottom_vec);
65   caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
67   Caffe::set_phase(Caffe::TEST);
69   // Test run
70   double train_accuracy = 0;
71   int batch_size = traintest_net_param.layers(0).layer().batchsize();
72   for (int i = 0; i < 60000 / batch_size; ++i) {
73     const vector<Blob<float>*>& result =
74         caffe_traintest_net.Forward(bottom_vec);
75     train_accuracy += result[0]->cpu_data()[0];
76   }
77   train_accuracy /= 60000 / batch_size;
78   LOG(ERROR) << "Train accuracy:" << train_accuracy;
80   NetParameter test_net_param;
81   ReadProtoFromTextFile("data/lenet_test.prototxt", &test_net_param);
82   Net<float> caffe_test_net(test_net_param, bottom_vec);
83   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
85   // Test run
86   double test_accuracy = 0;
87   batch_size = test_net_param.layers(0).layer().batchsize();
88   for (int i = 0; i < 10000 / batch_size; ++i) {
89     const vector<Blob<float>*>& result =
90         caffe_test_net.Forward(bottom_vec);
91     test_accuracy += result[0]->cpu_data()[0];
92   }
93   test_accuracy /= 10000 / batch_size;
94   LOG(ERROR) << "Test accuracy:" << test_accuracy;
96   return 0;
97 }