// Copyright 2013 Yangqing Jia

#include <stdint.h>
#include <leveldb/db.h>
#include <pthread.h>

#include <cstdlib>  // for rand()
#include <cstring>  // for memcpy
#include <string>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using std::string;

namespace caffe {

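// DataLayerPrefetch is the body of the prefetch pthread (started from SetUp
// and restarted from every Forward call). It decodes one batch of Datum
// records from the leveldb iterator into the layer's prefetch_data_ and
// prefetch_label_ host buffers, applying mean subtraction, scaling, and
// optional cropping and mirroring along the way.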
template <typename Dtype>
void* DataLayerPrefetch(void* layer_pointer) {
  DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
  Datum datum;
  Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
  Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
  const Dtype scale = layer->layer_param_.scale();
  const int batchsize = layer->layer_param_.batchsize();
  const int cropsize = layer->layer_param_.cropsize();
  const bool mirror = layer->layer_param_.mirror();

  if (mirror && cropsize == 0) {
    LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
        << "set at the same time.";
  }
  // datum dimensions
  const int channels = layer->datum_channels_;
  const int height = layer->datum_height_;
  const int width = layer->datum_width_;
  const int size = layer->datum_size_;
  const Dtype* mean = layer->data_mean_.cpu_data();
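  // Fill the prefetch buffers one item at a time, reading records
  // sequentially from the leveldb iterator and wrapping around to the
  // beginning when the end of the database is reached.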
  for (int itemid = 0; itemid < batchsize; ++itemid) {
    // get a blob
    datum.ParseFromString(layer->iter_->value().ToString());
    const string& data = datum.data();
    if (cropsize) {
      CHECK(data.size()) << "Image cropping only supports uint8 data";
      int h_off, w_off;
      // We only take a random crop during training; at test time we use the
      // center crop.
      if (Caffe::phase() == Caffe::TRAIN) {
        h_off = rand() % (height - cropsize);
        w_off = rand() % (width - cropsize);
      } else {
        h_off = (height - cropsize) / 2;
        w_off = (width - cropsize) / 2;
      }
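      // When mirroring is enabled, a random half of the items is also
      // flipped horizontally (the w index is reversed within the crop).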
      if (mirror && rand() % 2) {
        // Copy mirrored version
        for (int c = 0; c < channels; ++c) {
          for (int h = 0; h < cropsize; ++h) {
            for (int w = 0; w < cropsize; ++w) {
              top_data[((itemid * channels + c) * cropsize + h) * cropsize
                       + cropsize - 1 - w] =
                  (static_cast<Dtype>(
                      (uint8_t)data[(c * height + h + h_off) * width
                                    + w + w_off])
                   - mean[(c * height + h + h_off) * width + w + w_off])
                  * scale;
            }
          }
        }
      } else {
        // Normal copy
        for (int c = 0; c < channels; ++c) {
          for (int h = 0; h < cropsize; ++h) {
            for (int w = 0; w < cropsize; ++w) {
              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
                  = (static_cast<Dtype>(
                      (uint8_t)data[(c * height + h + h_off) * width
                                    + w + w_off])
                     - mean[(c * height + h + h_off) * width + w + w_off])
                  * scale;
            }
          }
        }
      }
    } else {
      // No cropping: prefer the uint8 data() field if it is set, and fall
      // back to float_data() otherwise.
      if (data.size()) {
        for (int j = 0; j < size; ++j) {
          top_data[itemid * size + j] =
              (static_cast<Dtype>((uint8_t)data[j]) - mean[j]) * scale;
        }
      } else {
        for (int j = 0; j < size; ++j) {
          top_data[itemid * size + j] =
              (datum.float_data(j) - mean[j]) * scale;
        }
      }
    }

    top_label[itemid] = datum.label();
    // go to the next iter
    layer->iter_->Next();
    if (!layer->iter_->Valid()) {
      // We have reached the end. Restart from the first.
      DLOG(INFO) << "Restarting data prefetching from start.";
      layer->iter_->SeekToFirst();
    }
  }

  return reinterpret_cast<void*>(NULL);
}

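// SetUp opens the leveldb named by layer_param_.source(), uses the first
// Datum to shape the output blobs and the host-side prefetch buffers, loads
// the mean blob from meanfile if one is given (otherwise the mean stays
// zero), and launches the first prefetch thread.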
template <typename Dtype>
void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 0) << "Data Layer takes no input blobs.";
  CHECK_EQ(top->size(), 2) << "Data Layer takes two blobs as output.";
  // Initialize the leveldb
  leveldb::DB* db_temp;
  leveldb::Options options;
  options.create_if_missing = false;
  LOG(INFO) << "Opening leveldb " << this->layer_param_.source();
  leveldb::Status status = leveldb::DB::Open(
      options, this->layer_param_.source(), &db_temp);
  CHECK(status.ok()) << "Failed to open leveldb "
      << this->layer_param_.source();
  db_.reset(db_temp);
  iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
  iter_->SeekToFirst();
  // Read a data point, and use it to initialize the top blob.
  Datum datum;
  datum.ParseFromString(iter_->value().ToString());
  // image
  int cropsize = this->layer_param_.cropsize();
  if (cropsize > 0) {
    (*top)[0]->Reshape(
        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
    prefetch_data_.reset(new Blob<Dtype>(
        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
  } else {
    (*top)[0]->Reshape(
        this->layer_param_.batchsize(), datum.channels(), datum.height(),
        datum.width());
    prefetch_data_.reset(new Blob<Dtype>(
        this->layer_param_.batchsize(), datum.channels(), datum.height(),
        datum.width()));
  }
  LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
      << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
      << (*top)[0]->width();
  // label
  (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
  prefetch_label_.reset(
      new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
  // datum size
  datum_channels_ = datum.channels();
  datum_height_ = datum.height();
  datum_width_ = datum.width();
  datum_size_ = datum.channels() * datum.height() * datum.width();
  CHECK_GT(datum_height_, cropsize);
  CHECK_GT(datum_width_, cropsize);
  // check if we want to subtract a mean image
  if (this->layer_param_.has_meanfile()) {
    BlobProto blob_proto;
    LOG(INFO) << "Loading mean file from " << this->layer_param_.meanfile();
    ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
    data_mean_.FromProto(blob_proto);
    CHECK_EQ(data_mean_.num(), 1);
    CHECK_EQ(data_mean_.channels(), datum_channels_);
    CHECK_EQ(data_mean_.height(), datum_height_);
    CHECK_EQ(data_mean_.width(), datum_width_);
  } else {
    // Simply initialize an all-zero mean.
    data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
  }
  // Now, start the prefetch thread. Before doing so, we make the cpu_data
  // calls here so that the prefetch thread does not accidentally make
  // simultaneous cudaMalloc calls while the main thread is running. On some
  // GPUs this seems to cause failures if we do not do so.
  prefetch_data_->mutable_cpu_data();
  prefetch_label_->mutable_cpu_data();
  data_mean_.cpu_data();
  DLOG(INFO) << "Initializing prefetch";
  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
  DLOG(INFO) << "Prefetch initialized.";
}

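// Forward waits for the in-flight prefetch thread to finish, copies the
// prefetched batch into the top blobs, and immediately launches the next
// prefetch so that it overlaps with the rest of the forward/backward pass.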
template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  // First, join the thread
  CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
  // Copy the data
  memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
      sizeof(Dtype) * prefetch_data_->count());
  memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
      sizeof(Dtype) * prefetch_label_->count());
  // Start a new prefetch thread
  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

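// The GPU forward is identical except that the prefetched host buffers are
// copied to the device with cudaMemcpy.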
template <typename Dtype>
void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  // First, join the thread
  CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
  // Copy the data
  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
      cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
      cudaMemcpyHostToDevice));
  // Start a new prefetch thread
  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

// The backward operations are dummy - they do not carry any computation.
template <typename Dtype>
Dtype DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  return Dtype(0.);
}

template <typename Dtype>
Dtype DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  return Dtype(0.);
}

INSTANTIATE_CLASS(DataLayer);

}  // namespace caffe