4c8d3fd2a438be49b4e09b038c46f0b70eac4d76
[jacinto-ai/caffe-jacinto.git] / src / caffe / test / test_solver_mnist.cpp
1 // Copyright 2013 Yangqing Jia
3 #include <cuda_runtime.h>
4 #include <fcntl.h>
5 #include <google/protobuf/text_format.h>
6 #include <google/protobuf/io/zero_copy_stream_impl.h>
7 #include <gtest/gtest.h>
9 #include <cstring>
11 #include "caffe/blob.hpp"
12 #include "caffe/common.hpp"
13 #include "caffe/net.hpp"
14 #include "caffe/filler.hpp"
15 #include "caffe/proto/caffe.pb.h"
16 #include "caffe/util/io.hpp"
17 #include "caffe/optimization/solver.hpp"
19 #include "caffe/test/test_caffe_main.hpp"
21 namespace caffe {
23 template <typename Dtype>
24 class MNISTSolverTest : public ::testing::Test {};
26 typedef ::testing::Types<float> Dtypes;
27 TYPED_TEST_CASE(MNISTSolverTest, Dtypes);
29 TYPED_TEST(MNISTSolverTest, TestSolve) {
30   Caffe::set_mode(Caffe::GPU);
32   NetParameter net_param;
33   ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
34       &net_param);
35   vector<Blob<TypeParam>*> bottom_vec;
36   Net<TypeParam> caffe_net(net_param, bottom_vec);
38   // Run the network without training.
39   LOG(ERROR) << "Performing Forward";
40   caffe_net.Forward(bottom_vec);
41   LOG(ERROR) << "Performing Backward";
42   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
44   SolverParameter solver_param;
45   solver_param.set_base_lr(0.01);
46   solver_param.set_display(0);
47   solver_param.set_max_iter(6000);
48   solver_param.set_lr_policy("inv");
49   solver_param.set_gamma(0.0001);
50   solver_param.set_power(0.75);
51   solver_param.set_momentum(0.9);
53   LOG(ERROR) << "Starting Optimization";
54   SGDSolver<TypeParam> solver(solver_param);
55   solver.Solve(&caffe_net);
56   LOG(ERROR) << "Optimization Done.";
58   // Run the network after training.
59   LOG(ERROR) << "Performing Forward";
60   caffe_net.Forward(bottom_vec);
61   LOG(ERROR) << "Performing Backward";
62   TypeParam loss = caffe_net.Backward();
63   LOG(ERROR) << "Final loss: " << loss;
64   EXPECT_LE(loss, 0.5);
66   NetParameter trained_net_param;
67   caffe_net.ToProto(&trained_net_param);
68   // LOG(ERROR) << "Writing to disk.";
69   // WriteProtoToBinaryFile(trained_net_param,
70   //     "caffe/test/data/lenet_trained.prototxt");
72   NetParameter traintest_net_param;
73   ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
74       &traintest_net_param);
75   Net<TypeParam> caffe_traintest_net(traintest_net_param, bottom_vec);
76   caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
78   // Test run
79   double train_accuracy = 0;
80   int batch_size = traintest_net_param.layers(0).layer().batchsize();
81   for (int i = 0; i < 60000 / batch_size; ++i) {
82     const vector<Blob<TypeParam>*>& result =
83         caffe_traintest_net.Forward(bottom_vec);
84     train_accuracy += result[0]->cpu_data()[0];
85   }
86   train_accuracy /= 60000 / batch_size;
87   LOG(ERROR) << "Train accuracy:" << train_accuracy;
88   EXPECT_GE(train_accuracy, 0.98);
90   NetParameter test_net_param;
91   ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
92   Net<TypeParam> caffe_test_net(test_net_param, bottom_vec);
93   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
95   // Test run
96   double test_accuracy = 0;
97   batch_size = test_net_param.layers(0).layer().batchsize();
98   for (int i = 0; i < 10000 / batch_size; ++i) {
99     const vector<Blob<TypeParam>*>& result =
100         caffe_test_net.Forward(bottom_vec);
101     test_accuracy += result[0]->cpu_data()[0];
102   }
103   test_accuracy /= 10000 / batch_size;
104   LOG(ERROR) << "Test accuracy:" << test_accuracy;
105   EXPECT_GE(test_accuracy, 0.98);
108 }  // namespace caffe