b2403460138de4aae6e629eb3fcd0b42f813f695
1 // Copyright 2013 Yangqing Jia
3 #include <cuda_runtime.h>
4 #include <fcntl.h>
5 #include <google/protobuf/text_format.h>
6 #include <google/protobuf/io/zero_copy_stream_impl.h>
7 #include <gtest/gtest.h>
9 #include <cstring>
11 #include "caffe/blob.hpp"
12 #include "caffe/common.hpp"
13 #include "caffe/net.hpp"
14 #include "caffe/filler.hpp"
15 #include "caffe/proto/caffe.pb.h"
16 #include "caffe/util/io.hpp"
18 #include "caffe/test/lenet.hpp"
19 #include "caffe/test/test_caffe_main.hpp"
21 namespace caffe {
23 template <typename Dtype>
24 class NetProtoTest : public ::testing::Test {};
26 typedef ::testing::Types<float, double> Dtypes;
27 TYPED_TEST_CASE(NetProtoTest, Dtypes);
29 TYPED_TEST(NetProtoTest, TestLoadFromText) {
30 NetParameter net_param;
31 ReadProtoFromTextFile("caffe/test/data/lenet.prototxt", &net_param);
32 }
34 TYPED_TEST(NetProtoTest, TestSetup) {
35 NetParameter net_param;
36 string lenet_string(kLENET);
37 // Load the network
38 CHECK(google::protobuf::TextFormat::ParseFromString(
39 lenet_string, &net_param));
40 // check if things are right
41 EXPECT_EQ(net_param.layers_size(), 9);
42 EXPECT_EQ(net_param.bottom_size(), 2);
43 EXPECT_EQ(net_param.top_size(), 0);
45 // Now, initialize a network using the parameter
46 shared_ptr<Blob<TypeParam> > data(new Blob<TypeParam>(10, 1, 28, 28));
47 shared_ptr<Blob<TypeParam> > label(new Blob<TypeParam>(10, 1, 1, 1));
48 FillerParameter filler_param;
49 shared_ptr<Filler<TypeParam> > filler;
50 filler.reset(new ConstantFiller<TypeParam>(filler_param));
51 filler->Fill(label.get());
52 filler.reset(new UniformFiller<TypeParam>(filler_param));
53 filler->Fill(data.get());
55 vector<Blob<TypeParam>*> bottom_vec;
56 bottom_vec.push_back(data.get());
57 bottom_vec.push_back(label.get());
59 Net<TypeParam> caffe_net(net_param, bottom_vec);
60 EXPECT_EQ(caffe_net.layer_names().size(), 9);
61 EXPECT_EQ(caffe_net.blob_names().size(), 10);
63 // Print a few statistics to see if things are correct
64 for (int i = 0; i < caffe_net.blobs().size(); ++i) {
65 LOG(ERROR) << "Blob: " << caffe_net.blob_names()[i];
66 LOG(ERROR) << "size: " << caffe_net.blobs()[i]->num() << ", "
67 << caffe_net.blobs()[i]->channels() << ", "
68 << caffe_net.blobs()[i]->height() << ", "
69 << caffe_net.blobs()[i]->width();
70 }
71 // Run the network without training.
72 vector<Blob<TypeParam>*> top_vec;
73 LOG(ERROR) << "Performing Forward";
74 caffe_net.Forward(bottom_vec, &top_vec);
75 LOG(ERROR) << "Performing Backward";
76 LOG(ERROR) << caffe_net.Backward();
77 }
79 } // namespace caffe