]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - examples/test_net.cpp
point caffe url to bvlc
[jacinto-ai/caffe-jacinto.git] / examples / test_net.cpp
1 // Copyright 2013 Yangqing Jia
2 //
3 // This is a simple script that allows one to quickly test a network whose
4 // structure is specified by text format protocol buffers, and whose parameter
5 // are loaded from a pre-trained network.
6 // Usage:
7 //    test_net net_proto pretrained_net_proto iterations [CPU/GPU]
9 #include <cuda_runtime.h>
11 #include <cstring>
12 #include <cstdlib>
14 #include "caffe/caffe.hpp"
16 using namespace caffe;
18 int main(int argc, char** argv) {
19   if (argc < 4) {
20     LOG(ERROR) << "test_net net_proto pretrained_net_proto iterations [CPU/GPU]";
21     return 0;
22   }
24   cudaSetDevice(0);
25   Caffe::set_phase(Caffe::TEST);
27   if (argc == 5 && strcmp(argv[4], "GPU") == 0) {
28     LOG(ERROR) << "Using GPU";
29     Caffe::set_mode(Caffe::GPU);
30   } else {
31     LOG(ERROR) << "Using CPU";
32     Caffe::set_mode(Caffe::CPU);
33   }
35   NetParameter test_net_param;
36   ReadProtoFromTextFile(argv[1], &test_net_param);
37   Net<float> caffe_test_net(test_net_param);
38   NetParameter trained_net_param;
39   ReadProtoFromBinaryFile(argv[2], &trained_net_param);
40   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
42   int total_iter = atoi(argv[3]);
43   LOG(ERROR) << "Running " << total_iter << "Iterations.";
45   double test_accuracy = 0;
46   vector<Blob<float>*> dummy_blob_input_vec;
47   for (int i = 0; i < total_iter; ++i) {
48     const vector<Blob<float>*>& result =
49         caffe_test_net.Forward(dummy_blob_input_vec);
50     test_accuracy += result[0]->cpu_data()[0];
51     LOG(ERROR) << "Batch " << i << ", accuracy: " << result[0]->cpu_data()[0];
52   }
53   test_accuracy /= total_iter;
54   LOG(ERROR) << "Test accuracy:" << test_accuracy;
56   return 0;
57 }