1 // Copyright Yangqing Jia 2013
2 //
// This script converts the CIFAR dataset to the leveldb format used
// by caffe to perform classification.
// Usage:
//    convert_cifar_data input_folder output_folder
// Where the input folder should contain the binary batch files.
// The CIFAR dataset could be downloaded at
//    http://www.cs.toronto.edu/~kriz/cifar.html
#include <google/protobuf/text_format.h>
#include <glog/logging.h>
#include <leveldb/db.h>

#include <stdint.h>
#include <cstdio>
#include <iostream>
#include <fstream>
#include <string>

#include "caffe/proto/caffe.pb.h"

using std::string;
// Layout constants for the CIFAR binary batch files handled in this file.

// Side length in pixels; used as both the datum height and width.
const int kCIFAR_SIZE = 32;
// Image payload bytes read per record: 3 channels of SIZE x SIZE bytes.
const int kCIFAR_IMAGE_NBYTES = 3 * kCIFAR_SIZE * kCIFAR_SIZE;
// Number of records read from each batch file.
const int kCIFAR_BATCHSIZE = 10000;
// Number of training files (data_batch_1.bin .. data_batch_N.bin).
const int kCIFAR_TRAIN_BATCHES = 5;
29 void read_image(std::ifstream& file, int* label, char* buffer) {
30 char label_char;
31 file.read(&label_char, 1);
32 *label = label_char;
33 file.read(buffer, kCIFAR_IMAGE_NBYTES);
34 return;
35 }
37 void convert_dataset(const string& input_folder, const string& output_folder) {
38 // Leveldb options
39 leveldb::Options options;
40 options.create_if_missing = true;
41 options.error_if_exists = true;
42 // Data buffer
43 int label;
44 char str_buffer[kCIFAR_IMAGE_NBYTES];
45 string value;
46 caffe::Datum datum;
47 datum.set_channels(3);
48 datum.set_height(kCIFAR_SIZE);
49 datum.set_width(kCIFAR_SIZE);
51 LOG(INFO) << "Writing Training data";
52 leveldb::DB* train_db;
53 leveldb::Status status;
54 status = leveldb::DB::Open(options, output_folder + "/cifar-train-leveldb",
55 &train_db);
56 CHECK(status.ok()) << "Failed to open leveldb.";
57 for (int fileid = 0; fileid < kCIFAR_TRAIN_BATCHES; ++fileid) {
58 // Open files
59 LOG(INFO) << "Training Batch " << fileid + 1;
60 sprintf(str_buffer, "/data_batch_%d.bin", fileid + 1);
61 std::ifstream data_file((input_folder + str_buffer).c_str(),
62 std::ios::in | std::ios::binary);
63 CHECK(data_file) << "Unable to open train file #" << fileid + 1;
64 for (int itemid = 0; itemid < kCIFAR_BATCHSIZE; ++itemid) {
65 read_image(data_file, &label, str_buffer);
66 datum.set_label(label);
67 datum.set_data(str_buffer, kCIFAR_IMAGE_NBYTES);
68 datum.SerializeToString(&value);
69 sprintf(str_buffer, "%05d", fileid * kCIFAR_BATCHSIZE + itemid);
70 train_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
71 }
72 }
74 LOG(INFO) << "Writing Testing data";
75 leveldb::DB* test_db;
76 CHECK(leveldb::DB::Open(options, output_folder + "/cifar-test-leveldb",
77 &test_db).ok()) << "Failed to open leveldb.";
78 // Open files
79 std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
80 std::ios::in | std::ios::binary);
81 CHECK(data_file) << "Unable to open test file.";
82 for (int itemid = 0; itemid < kCIFAR_BATCHSIZE; ++itemid) {
83 read_image(data_file, &label, str_buffer);
84 datum.set_label(label);
85 datum.set_data(str_buffer, kCIFAR_IMAGE_NBYTES);
86 datum.SerializeToString(&value);
87 sprintf(str_buffer, "%05d", itemid);
88 test_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
89 }
91 delete train_db;
92 delete test_db;
93 }
95 int main (int argc, char** argv) {
96 if (argc != 3) {
97 printf("This script converts the CIFAR dataset to the leveldb format used\n"
98 "by caffe to perform classification.\n"
99 "Usage:\n"
100 " convert_cifar_data input_folder output_folder\n"
101 "Where the input folder should contain the binary batch files.\n"
102 "The CIFAR dataset could be downloaded at\n"
103 " http://www.cs.toronto.edu/~kriz/cifar.html\n"
104 "You should gunzip them after downloading.\n");
105 } else {
106 google::InitGoogleLogging(argv[0]);
107 convert_dataset(string(argv[1]), string(argv[2]));
108 }
109 return 0;
110 }