ce0d24e549953fd41765141d57a61a8a442a40f4
1 // Copyright 2013 Yangqing Jia
3 #include <cuda_runtime.h>
4 #include <leveldb/db.h>
6 #include <string>
8 #include "gtest/gtest.h"
9 #include "caffe/blob.hpp"
10 #include "caffe/common.hpp"
11 #include "caffe/filler.hpp"
12 #include "caffe/vision_layers.hpp"
13 #include "caffe/proto/caffe.pb.h"
14 #include "caffe/test/test_caffe_main.hpp"
16 using std::string;
18 namespace caffe {
20 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
22 template <typename Dtype>
23 class DataLayerTest : public ::testing::Test {
24 protected:
25 DataLayerTest()
26 : blob_top_data_(new Blob<Dtype>()),
27 blob_top_label_(new Blob<Dtype>()),
28 filename(NULL) {};
29 virtual void SetUp() {
30 blob_top_vec_.push_back(blob_top_data_);
31 blob_top_vec_.push_back(blob_top_label_);
32 // Create the leveldb
33 filename = tmpnam(NULL); // get temp name
34 LOG(ERROR) << "Using temporary leveldb " << filename;
35 leveldb::DB* db;
36 leveldb::Options options;
37 options.error_if_exists = true;
38 options.create_if_missing = true;
39 leveldb::Status status = leveldb::DB::Open(options, filename, &db);
40 CHECK(status.ok());
41 for (int i = 0; i < 5; ++i) {
42 Datum datum;
43 datum.set_label(i);
44 datum.set_channels(2);
45 datum.set_height(3);
46 datum.set_width(4);
47 std::string* data = datum.mutable_data();
48 for (int j = 0; j < 24; ++j) {
49 data->push_back((uint8_t)i);
50 }
51 stringstream ss;
52 ss << i;
53 db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
54 }
55 delete db;
56 };
58 virtual ~DataLayerTest() { delete blob_top_data_; delete blob_top_label_; }
60 char* filename;
61 Blob<Dtype>* const blob_top_data_;
62 Blob<Dtype>* const blob_top_label_;
63 vector<Blob<Dtype>*> blob_bottom_vec_;
64 vector<Blob<Dtype>*> blob_top_vec_;
65 };
67 typedef ::testing::Types<float, double> Dtypes;
68 TYPED_TEST_CASE(DataLayerTest, Dtypes);
70 TYPED_TEST(DataLayerTest, TestRead) {
71 LayerParameter param;
72 param.set_batchsize(5);
73 param.set_source(this->filename);
74 DataLayer<TypeParam> layer(param);
75 layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
76 EXPECT_EQ(this->blob_top_data_->num(), 5);
77 EXPECT_EQ(this->blob_top_data_->channels(), 2);
78 EXPECT_EQ(this->blob_top_data_->height(), 3);
79 EXPECT_EQ(this->blob_top_data_->width(), 4);
80 EXPECT_EQ(this->blob_top_label_->num(), 5);
81 EXPECT_EQ(this->blob_top_label_->channels(), 1);
82 EXPECT_EQ(this->blob_top_label_->height(), 1);
83 EXPECT_EQ(this->blob_top_label_->width(), 1);
84 // Go throught the data twice
85 for (int iter = 0; iter < 2; ++iter) {
86 layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
87 for (int i = 0; i < 5; ++i) {
88 EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
89 }
90 for (int i = 0; i < 5; ++i) {
91 for (int j = 0; j < 24; ++j) {
92 EXPECT_EQ(i, this->blob_top_data_->cpu_data()[i * 24 + j])
93 << "debug: i " << i << " j " << j;
94 }
95 }
96 }
97 }
99 }