commitdiff (parent: fb28244)
author    Yangqing Jia <jiayq84@gmail.com>   Tue, 1 Oct 2013 23:27:48 +0000 (16:27 -0700)
committer Yangqing Jia <jiayq84@gmail.com>   Tue, 1 Oct 2013 23:27:48 +0000 (16:27 -0700)
.gitignore
src/Makefile
src/caffe/layers/loss_layer.cu
src/caffe/layers/softmax_layer.cpp    [deleted file]
src/caffe/layers/softmax_layer.cu     [new file with mode: 0644]
src/caffe/test/test_solver_mnist.cpp  [deleted file]
src/caffe/util/io.cpp
src/caffe/util/io.hpp
src/caffe/vision_layers.hpp
src/programs/convert_dataset.cpp      [new file with mode: 0644]
diff --git a/.gitignore b/.gitignore
index 14428f6da9c9bc09e9cf8468ac589b7c9809c6c9..bc38afc9dc063fbd9fb770592dadca8144ead05c 100644 (file)
--- a/.gitignore
+++ b/.gitignore
*.pb.cc
*_pb2.py
-# test files
+# bin files
*.testbin
+*.bin
# vim swp files
*.swp
diff --git a/src/Makefile b/src/Makefile
index 05b7bc042780bdf9177a97a0b67f8c084b79c613..31d225e946c8e18572b4c9070b0f3d623a4bcfb3 100644 (file)
--- a/src/Makefile
+++ b/src/Makefile
protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
clean:
- @- $(RM) $(NAME) $(TEST_BINS)
- @- $(RM) $(OBJS) $(TEST_OBJS)
+ @- $(RM) $(NAME) $(TEST_BINS) $(PROGRAM_BINS)
+ @- $(RM) $(OBJS) $(TEST_OBJS) $(PROGRAM_OBJS)
@- $(RM) $(PROTO_GEN_HEADER) $(PROTO_GEN_CC) $(PROTO_GEN_PY)
distclean: clean
diff --git a/src/caffe/layers/loss_layer.cu b/src/caffe/layers/loss_layer.cu
index 1ea0626c377d2f8495e4d555914626718a2dea3d..737f1a23987b60796350118550f45fa82c952d22 100644 (file)
--- a/src/caffe/layers/loss_layer.cu
+++ b/src/caffe/layers/loss_layer.cu
// TODO: implement the GPU version for multinomial loss
+
template <typename Dtype>
void EuclideanLossLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
--- a/src/caffe/layers/softmax_layer.cpp
+++ /dev/null
@@ -1,91 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <algorithm>
-#include <vector>
-
-#include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
-#include "caffe/util/math_functions.hpp"
-
-using std::max;
-
-namespace caffe {
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
- CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
- (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
- sum_multiplier_.Reshape(1, bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
- Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
- for (int i = 0; i < sum_multiplier_.count(); ++i) {
- multiplier_data[i] = 1.;
- }
- scale_.Reshape(bottom[0]->num(), 1, 1, 1);
-};
-
-template <typename Dtype>
-void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- Dtype* top_data = (*top)[0]->mutable_cpu_data();
- Dtype* scale_data = scale_.mutable_cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
- memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
- // we need to subtract the max to avoid numerical issues, compute the exp,
- // and then normalize.
- // Compute sum
- for (int i = 0; i < num; ++i) {
- scale_data[i] = bottom_data[i*dim];
- for (int j = 0; j < dim; ++j) {
- scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
- }
- }
- // subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
- scale_data, sum_multiplier_.cpu_data(), 1., top_data);
- // Perform exponentiation
- caffe_exp<Dtype>(num * dim, top_data, top_data);
- // sum after exp
- caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
- sum_multiplier_.cpu_data(), 0., scale_data);
- // Do division
- for (int i = 0; i < num; ++i) {
- caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
- }
-}
-
-template <typename Dtype>
-Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const bool propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- const Dtype* top_diff = top[0]->cpu_diff();
- const Dtype* top_data = top[0]->cpu_data();
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- Dtype* scale_data = scale_.mutable_cpu_data();
- int num = top[0]->num();
- int dim = top[0]->count() / top[0]->num();
- memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
- // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
- for (int i = 0; i < num; ++i) {
- scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
- top_data + i * dim);
- }
- // subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
- scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
- // elementwise multiplication
- caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
- return Dtype(0);
-}
-
-// TODO(Yangqing): implement the GPU version of softmax.
-
-INSTANTIATE_CLASS(SoftmaxLayer);
-
-
-} // namespace caffe
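
The forward pass above (which reappears in the new CUDA file below) first copies the input, then subtracts each row's maximum before exponentiating and normalizing. A short note, in LaTeX, on why the shift leaves the result unchanged while avoiding overflow:

    % Row-wise softmax with the per-row maximum m_i = \max_k x_{ik} subtracted:
    \mathrm{softmax}(x_i)_j
        = \frac{e^{x_{ij}}}{\sum_k e^{x_{ik}}}
        = \frac{e^{x_{ij} - m_i}}{\sum_k e^{x_{ik} - m_i}},
        \qquad m_i = \max_k x_{ik}
    % The common factor e^{-m_i} cancels, so the value is identical, but every
    % exponent is now \le 0 and e^{x_{ij} - m_i} \in (0, 1], which cannot overflow.
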
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
--- /dev/null
+++ b/src/caffe/layers/softmax_layer.cu
@@ -0,0 +1,181 @@
+// Copyright 2013 Yangqing Jia
+
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+#include <thrust/device_vector.h>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
+ CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
+ (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+ sum_multiplier_.Reshape(1, bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+ Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+ for (int i = 0; i < sum_multiplier_.count(); ++i) {
+ multiplier_data[i] = 1.;
+ }
+ scale_.Reshape(bottom[0]->num(), 1, 1, 1);
+};
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = bottom[0]->num();
+ int dim = bottom[0]->count() / bottom[0]->num();
+ memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+ // we need to subtract the max to avoid numerical issues, compute the exp,
+ // and then normalize.
+ for (int i = 0; i < num; ++i) {
+ scale_data[i] = bottom_data[i*dim];
+ for (int j = 0; j < dim; ++j) {
+ scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
+ }
+ }
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_data, sum_multiplier_.cpu_data(), 1., top_data);
+ // Perform exponentiation
+ caffe_exp<Dtype>(num * dim, top_data, top_data);
+ // sum after exp
+ caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+ sum_multiplier_.cpu_data(), 0., scale_data);
+ // Do division
+ for (int i = 0; i < num; ++i) {
+ caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
+ }
+}
+
+template <typename Dtype>
+__global__ void kernel_get_max(const int num, const int dim,
+ const Dtype* data, Dtype* out) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < num) {
+ Dtype maxval = -FLT_MAX;
+ for (int i = 0; i < dim; ++i) {
+ maxval = max(data[index * dim + i], maxval);
+ }
+ out[index] = maxval;
+ }
+}
+
+template <typename Dtype>
+__global__ void kernel_softmax_div(const int num, const int dim,
+ const Dtype* scale, Dtype* data) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < num * dim) {
+ int n = index / dim;
+ data[index] /= scale[n];
+ }
+}
+
+template <typename Dtype>
+__global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < num) {
+ out[index] = exp(data[index]);
+ }
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = (*top)[0]->mutable_gpu_data();
+ Dtype* scale_data = scale_.mutable_gpu_data();
+ int num = bottom[0]->num();
+ int dim = bottom[0]->count() / bottom[0]->num();
+ CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
+ sizeof(Dtype) * bottom[0]->count(), cudaMemcpyDeviceToDevice));
+ // we need to subtract the max to avoid numerical issues, compute the exp,
+ // and then normalize.
+ // Compute max
+ kernel_get_max<Dtype><<<CAFFE_GET_BLOCKS(num), CAFFE_CUDA_NUM_THREADS>>>(
+ num, dim, bottom_data, scale_data);
+ // subtraction
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_data, sum_multiplier_.gpu_data(), 1., top_data);
+ // Perform exponentiation
+ kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+ num * dim, top_data, top_data);
+ // sum after exp
+ caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+ sum_multiplier_.gpu_data(), 0., scale_data);
+ // Do division
+ kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
+ num, dim, scale_data, top_data);
+}
+
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* top_data = top[0]->cpu_data();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = top[0]->num();
+ int dim = top[0]->count() / top[0]->num();
+ memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+ // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+ for (int i = 0; i < num; ++i) {
+ scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
+ top_data + i * dim);
+ }
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
+ // elementwise multiplication
+ caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+ return Dtype(0);
+}
+
+// TODO(Yangqing): implement the GPU version of softmax.
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* top_data = top[0]->gpu_data();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+ int num = top[0]->num();
+ int dim = top[0]->count() / top[0]->num();
+ CUDA_CHECK(cudaMemcpy(bottom_diff, top_diff,
+ sizeof(Dtype) * top[0]->count(), cudaMemcpyDeviceToDevice));
+ // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+ // cuda dot returns the result to cpu, so we temporarily change the pointer
+ // mode
+ CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+ CUBLAS_POINTER_MODE_DEVICE));
+ Dtype* scale_data = scale_.mutable_gpu_data();
+ for (int i = 0; i < num; ++i) {
+ caffe_gpu_dot<Dtype>(dim, top_diff + i * dim,
+ top_data + i * dim, scale_data + i);
+ }
+ CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
+ CUBLAS_POINTER_MODE_HOST));
+ // subtraction
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_.gpu_data(), sum_multiplier_.gpu_data(), 1., bottom_diff);
+ // elementwise multiplication
+ caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+ return Dtype(0);
+}
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+
+
+} // namespace caffe
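
In both the CPU and GPU paths, the gemm against sum_multiplier_ (a blob filled with ones in SetUp) is a rank-1 update that broadcasts a per-row scalar across all dim entries, the gemv computes per-row sums, and the per-row dot product in Backward is the softmax Jacobian-vector product dL/dx_ij = y_ij * (dL/dy_ij - sum_k dL/dy_ik * y_ik). A BLAS-free sketch of the same row-wise computation (illustrative helpers, not part of this commit):

    #include <algorithm>
    #include <cmath>

    // Reference softmax forward/backward for a single row of length dim,
    // mirroring what the layer does with caffe_cpu_gemm/gemv + sum_multiplier_.
    template <typename Dtype>
    void SoftmaxRowForward(const Dtype* x, const int dim, Dtype* y) {
      Dtype m = x[0];
      for (int j = 1; j < dim; ++j) m = std::max(m, x[j]);   // row max
      Dtype sum = 0;
      for (int j = 0; j < dim; ++j) { y[j] = std::exp(x[j] - m); sum += y[j]; }
      for (int j = 0; j < dim; ++j) y[j] /= sum;              // normalize
    }

    template <typename Dtype>
    void SoftmaxRowBackward(const Dtype* y, const Dtype* top_diff, const int dim,
                            Dtype* bottom_diff) {
      Dtype dot = 0;  // inner1d(top_diff, y), the value stored in scale_ above
      for (int j = 0; j < dim; ++j) dot += top_diff[j] * y[j];
      for (int j = 0; j < dim; ++j)
        bottom_diff[j] = y[j] * (top_diff[j] - dot);          // y .* (dy - <dy, y>)
    }

Backward_gpu switches cuBLAS to CUBLAS_POINTER_MODE_DEVICE for the reason the comment gives: caffe_gpu_dot would otherwise return each row's dot product to the host, whereas here it is written straight into scale_ on the device.
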
diff --git a/src/caffe/test/test_solver_mnist.cpp b/src/caffe/test/test_solver_mnist.cpp
--- a/src/caffe/test/test_solver_mnist.cpp
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-#include <google/protobuf/io/zero_copy_stream_impl.h>
-#include <gtest/gtest.h>
-
-#include <cstring>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/optimization/solver.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-class MNISTSolverTest : public ::testing::Test {};
-
-typedef ::testing::Types<float> Dtypes;
-TYPED_TEST_CASE(MNISTSolverTest, Dtypes);
-
-TYPED_TEST(MNISTSolverTest, TestSolve) {
- Caffe::set_mode(Caffe::GPU);
-
- NetParameter net_param;
- ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
- &net_param);
- vector<Blob<TypeParam>*> bottom_vec;
- Net<TypeParam> caffe_net(net_param, bottom_vec);
-
- // Run the network without training.
- LOG(ERROR) << "Performing Forward";
- caffe_net.Forward(bottom_vec);
- LOG(ERROR) << "Performing Backward";
- LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
- SolverParameter solver_param;
- solver_param.set_base_lr(0.01);
- solver_param.set_display(0);
- solver_param.set_max_iter(6000);
- solver_param.set_lr_policy("inv");
- solver_param.set_gamma(0.0001);
- solver_param.set_power(0.75);
- solver_param.set_momentum(0.9);
-
- LOG(ERROR) << "Starting Optimization";
- SGDSolver<TypeParam> solver(solver_param);
- solver.Solve(&caffe_net);
- LOG(ERROR) << "Optimization Done.";
-
- // Run the network after training.
- LOG(ERROR) << "Performing Forward";
- caffe_net.Forward(bottom_vec);
- LOG(ERROR) << "Performing Backward";
- TypeParam loss = caffe_net.Backward();
- LOG(ERROR) << "Final loss: " << loss;
- EXPECT_LE(loss, 0.5);
-
- NetParameter trained_net_param;
- caffe_net.ToProto(&trained_net_param);
- // LOG(ERROR) << "Writing to disk.";
- // WriteProtoToBinaryFile(trained_net_param,
- // "caffe/test/data/lenet_trained.prototxt");
-
- NetParameter traintest_net_param;
- ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
- &traintest_net_param);
- Net<TypeParam> caffe_traintest_net(traintest_net_param, bottom_vec);
- caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
-
- // Test run
- double train_accuracy = 0;
- int batch_size = traintest_net_param.layers(0).layer().batchsize();
- for (int i = 0; i < 60000 / batch_size; ++i) {
- const vector<Blob<TypeParam>*>& result =
- caffe_traintest_net.Forward(bottom_vec);
- train_accuracy += result[0]->cpu_data()[0];
- }
- train_accuracy /= 60000 / batch_size;
- LOG(ERROR) << "Train accuracy:" << train_accuracy;
- EXPECT_GE(train_accuracy, 0.98);
-
- NetParameter test_net_param;
- ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
- Net<TypeParam> caffe_test_net(test_net_param, bottom_vec);
- caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
- // Test run
- double test_accuracy = 0;
- batch_size = test_net_param.layers(0).layer().batchsize();
- for (int i = 0; i < 10000 / batch_size; ++i) {
- const vector<Blob<TypeParam>*>& result =
- caffe_test_net.Forward(bottom_vec);
- test_accuracy += result[0]->cpu_data()[0];
- }
- test_accuracy /= 10000 / batch_size;
- LOG(ERROR) << "Test accuracy:" << test_accuracy;
- EXPECT_GE(test_accuracy, 0.98);
-}
-
-} // namespace caffe
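
The deleted test trains LeNet with SGD under the "inv" learning-rate policy (base_lr 0.01, gamma 0.0001, power 0.75, momentum 0.9) for 6000 iterations. Assuming the usual Caffe definition of that policy, lr(t) = base_lr * (1 + gamma * t)^(-power), a small stand-alone sketch of the schedule it produces (the formula is assumed, not shown in this diff):

    #include <cmath>
    #include <cstdio>

    // Hypothetical evaluation of the "inv" learning-rate policy with the
    // hyperparameters used by the deleted test.
    double InvPolicyRate(double base_lr, double gamma, double power, int iter) {
      return base_lr * std::pow(1.0 + gamma * iter, -power);
    }

    int main() {
      const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
      const int iters[] = {0, 1000, 3000, 6000};
      for (int i = 0; i < 4; ++i) {
        std::printf("iter %5d  lr %.6f\n", iters[i],
                    InvPolicyRate(base_lr, gamma, power, iters[i]));
      }
      return 0;
    }
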
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 0d4f9bb1ed4c087b1c581636d3c1b8fe7591e47c..b7a830bbcb507bbecdb9670750ee3e100bb8d343 100644 (file)
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
}
}
+void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+ Mat cv_img;
+ cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+ CHECK(cv_img.data) << "Could not open or find the image.";
+ DCHECK_EQ(cv_img.channels(), 3);
+ datum->set_channels(3);
+ datum->set_height(cv_img.rows);
+ datum->set_width(cv_img.cols);
+ datum->set_label(label);
+ datum->clear_data();
+ datum->clear_float_data();
+ string* datum_string = datum->mutable_data();
+ for (int c = 0; c < 3; ++c) {
+ for (int h = 0; h < cv_img.rows; ++h) {
+ for (int w = 0; w < cv_img.cols; ++w) {
+ datum_string->push_back(static_cast<char>(cv_img.at<Vec3b>(h, w)[c]));
+ }
+ }
+ }
+}
+
+
void WriteProtoToImage(const string& filename, const BlobProto& proto) {
CHECK_EQ(proto.num(), 1);
CHECK(proto.channels() == 3 || proto.channels() == 1);
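
ReadImageToDatum stores the pixels channel-major: the loops run over c, then h, then w, so channel c of pixel (h, w) ends up at offset (c * height + h) * width + w in the Datum's data string. A hedged sketch of indexing back into a Datum written this way (illustrative helper, not part of this commit):

    #include <string>
    #include "caffe/proto/caffe.pb.h"

    // Return the byte for channel c, row h, column w of a Datum filled by
    // ReadImageToDatum above.
    inline unsigned char DatumPixel(const caffe::Datum& datum,
                                    const int c, const int h, const int w) {
      const std::string& data = datum.data();
      const int index = (c * datum.height() + h) * datum.width() + w;
      return static_cast<unsigned char>(data[index]);
    }
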
diff --git a/src/caffe/util/io.hpp b/src/caffe/util/io.hpp
index 57beef1dc00685f06294ff1e813b70f1cbc075f5..ab4593668b473adc311af0468e0728bd3940828d 100644 (file)
--- a/src/caffe/util/io.hpp
+++ b/src/caffe/util/io.hpp
WriteProtoToImage(filename, proto);
}
+void ReadImageToDatum(const string& filename, const int label, Datum* datum);
+
void ReadProtoFromTextFile(const char* filename,
Message* proto);
inline void ReadProtoFromTextFile(const string& filename,
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index 3aa43b2ba81afcf584e83a4cd2aeea652180bcd9..74c597825eb0514ee7ed5f8ab171a194c51925c5 100644 (file)
--- a/src/caffe/vision_layers.hpp
+++ b/src/caffe/vision_layers.hpp
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
- // virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
- // vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
- // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
- // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
// sum_multiplier is just used to carry out sum using blas
Blob<Dtype> sum_multiplier_;
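
Un-commenting the Forward_gpu/Backward_gpu declarations is what lets the new softmax_layer.cu definitions be reached: the Layer base class dispatches on the global Caffe mode. A rough sketch of that dispatch, assumed rather than shown in this diff (signatures simplified):

    // Assumed shape of the base-class dispatch (not part of this commit).
    template <typename Dtype>
    void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
                               vector<Blob<Dtype>*>* top) {
      switch (Caffe::mode()) {
      case Caffe::CPU:
        Forward_cpu(bottom, top);
        break;
      case Caffe::GPU:
        Forward_gpu(bottom, top);
        break;
      default:
        LOG(FATAL) << "Unknown caffe mode.";
      }
    }
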
diff --git a/src/programs/convert_dataset.cpp b/src/programs/convert_dataset.cpp
--- /dev/null
+++ b/src/programs/convert_dataset.cpp
@@ -0,0 +1,66 @@
+// Copyright 2013 Yangqing Jia
+// This program converts a set of images to a leveldb by storing them as Datum
+// proto buffers.
+// Usage:
+// convert_dataset ROOTFOLDER LISTFILE DB_NAME
+// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
+// should be a list of files as well as their labels, in the format as
+// subfolder1/file1.JPEG 0
+// ....
+
+#include <glog/logging.h>
+#include <leveldb/db.h>
+
+#include <string>
+#include <iostream>
+#include <fstream>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+
+using namespace caffe;
+using std::string;
+
+// A utility function to generate random strings
+void GenerateRandomPrefix(const int n, string* key) {
+ const char* kCHARS = "abcdefghijklmnopqrstuvwxyz";
+ key->clear();
+ for (int i = 0; i < n; ++i) {
+ key->push_back(kCHARS[rand() % 26]);
+ }
+ key->push_back('_');
+}
+
+int main(int argc, char** argv) {
+ ::google::InitGoogleLogging(argv[0]);
+ std::ifstream infile(argv[2]);
+ leveldb::DB* db;
+ leveldb::Options options;
+ options.error_if_exists = true;
+ options.create_if_missing = true;
+ LOG(INFO) << "Opening leveldb " << argv[3];
+ leveldb::Status status = leveldb::DB::Open(
+ options, argv[3], &db);
+ CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
+
+ string root_folder(argv[1]);
+ string filename;
+ int label;
+ Datum datum;
+ string key;
+ string value;
+ while (infile >> filename >> label) {
+ ReadImageToDatum(root_folder + filename, label, &datum);
+ // get the key, and add a random string so the leveldb will have permuted
+ // data
+ GenerateRandomPrefix(8, &key);
+ key += filename;
+ // get the value
+ datum.SerializeToString(&value);
+ db->Put(leveldb::WriteOptions(), key, value);
+ LOG(ERROR) << "Writing " << key;
+ }
+
+ delete db;
+ return 0;
+}