summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: a17ea59)
raw | patch | inline | side by side (parent: a17ea59)
author | Yangqing Jia <jiayq84@gmail.com> | |
Tue, 24 Sep 2013 22:45:15 +0000 (15:45 -0700) | ||
committer | Yangqing Jia <jiayq84@gmail.com> | |
Tue, 24 Sep 2013 22:45:15 +0000 (15:45 -0700) |
src/caffe/layers/data_layer.cpp | patch | blob | history | |
src/caffe/test/test_convolution_layer.cpp | patch | blob | history | |
src/caffe/test/test_data_layer.cpp | [new file with mode: 0644] | patch | blob |
src/caffe/test/test_protobuf.cpp | patch | blob | history |
index 4b7f29d2873261b99855c13f6f5fb393967c10e5..3cd76d124cb7c7cef07890a858535d4217bb2289 100644 (file)
leveldb::DB* db_temp;
leveldb::Options options;
options.create_if_missing = false;
leveldb::DB* db_temp;
leveldb::Options options;
options.create_if_missing = false;
+ LOG(INFO) << "Opening leveldb " << this->layer_param_.source();
leveldb::Status status = leveldb::DB::Open(
options, this->layer_param_.source(), &db_temp);
CHECK(status.ok());
db_.reset(db_temp);
iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
leveldb::Status status = leveldb::DB::Open(
options, this->layer_param_.source(), &db_temp);
CHECK(status.ok());
db_.reset(db_temp);
iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
+ iter_->SeekToFirst();
// Read a data point, and use it to initialize the top blob.
Datum datum;
datum.ParseFromString(iter_->value().ToString());
// Read a data point, and use it to initialize the top blob.
Datum datum;
datum.ParseFromString(iter_->value().ToString());
vector<Blob<Dtype>*>* top) {
Datum datum;
Dtype* top_data = (*top)[0]->mutable_cpu_data();
vector<Blob<Dtype>*>* top) {
Datum datum;
Dtype* top_data = (*top)[0]->mutable_cpu_data();
- Dtype* top_label = (*top)[0]->mutable_cpu_diff();
+ Dtype* top_label = (*top)[1]->mutable_cpu_data();
for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
// get a blob
datum.ParseFromString(iter_->value().ToString());
for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
// get a blob
datum.ParseFromString(iter_->value().ToString());
if (!iter_->Valid()) {
// We have reached the end. Restart from the first.
LOG(INFO) << "Restarting data read from start.";
if (!iter_->Valid()) {
// We have reached the end. Restart from the first.
LOG(INFO) << "Restarting data read from start.";
- iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
+ iter_->SeekToFirst();
}
}
}
}
}
}
index 54517aac9f357f43ff92e02a1981966f95193a6c..5de33bc591e75e891dafaa587f3b8533002f28de 100644 (file)
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
-
+
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
template <typename Dtype>
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
template <typename Dtype>
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
--- /dev/null
@@ -0,0 +1,97 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cstring>
+#include <cuda_runtime.h>
+#include <leveldb/db.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/proto/layer_param.pb.h"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class DataLayerTest : public ::testing::Test {
+ protected:
+ DataLayerTest()
+ : blob_top_data_(new Blob<Dtype>()),
+ blob_top_label_(new Blob<Dtype>()),
+ filename(NULL) {};
+ virtual void SetUp() {
+ blob_top_vec_.push_back(blob_top_data_);
+ blob_top_vec_.push_back(blob_top_label_);
+ // Create the leveldb
+ filename = tmpnam(NULL); // get temp name
+ LOG(ERROR) << "Using temporary leveldb " << filename;
+ leveldb::DB* db;
+ leveldb::Options options;
+ options.error_if_exists = true;
+ options.create_if_missing = true;
+ leveldb::Status status = leveldb::DB::Open(options, filename, &db);
+ CHECK(status.ok());
+ for (int i = 0; i < 5; ++i) {
+ Datum datum;
+ datum.set_label(i);
+ BlobProto* blob = datum.mutable_blob();
+ blob->set_num(1);
+ blob->set_channels(2);
+ blob->set_height(3);
+ blob->set_width(4);
+ for (int j = 0; j < 24; ++j) {
+ blob->add_data(i);
+ }
+ stringstream ss;
+ ss << i;
+ db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
+ }
+ delete db;
+ };
+
+ virtual ~DataLayerTest() { delete blob_top_data_; delete blob_top_label_; }
+
+ char* filename;
+ Blob<Dtype>* const blob_top_data_;
+ Blob<Dtype>* const blob_top_label_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(DataLayerTest, Dtypes);
+
+TYPED_TEST(DataLayerTest, TestRead) {
+ LayerParameter param;
+ param.set_batchsize(5);
+ param.set_source(this->filename);
+ DataLayer<TypeParam> layer(param);
+ layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_data_->num(), 5);
+ EXPECT_EQ(this->blob_top_data_->channels(), 2);
+ EXPECT_EQ(this->blob_top_data_->height(), 3);
+ EXPECT_EQ(this->blob_top_data_->width(), 4);
+ EXPECT_EQ(this->blob_top_label_->num(), 5);
+ EXPECT_EQ(this->blob_top_label_->channels(), 1);
+ EXPECT_EQ(this->blob_top_label_->height(), 1);
+ EXPECT_EQ(this->blob_top_label_->width(), 1);
+ // Go throught the data twice
+ for (int iter = 0; iter < 2; ++iter) {
+ layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
+ }
+ for (int i = 0; i < 5; ++i) {
+ for (int j = 0; j < 24; ++j) {
+ EXPECT_EQ(i, this->blob_top_data_->cpu_data()[i * 24 + j])
+ << "debug: i " << i << " j " << j;
+ }
+ }
+ }
+}
+
+}
index 9ff1d71fcc7b07c625352b60f1d2ab89433c7a36..87dffa072c0307b2e2fa1ca696f1f6e19d100082 100644 (file)
#include "caffe/proto/layer_param.pb.h"
namespace caffe {
#include "caffe/proto/layer_param.pb.h"
namespace caffe {
-
+
class ProtoTest : public ::testing::Test {};
TEST_F(ProtoTest, TestSerialization) {
class ProtoTest : public ::testing::Test {};
TEST_F(ProtoTest, TestSerialization) {