]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/layers/data_layer.cpp
data layer race condition bugfix
[jacinto-ai/caffe-jacinto.git] / src / caffe / layers / data_layer.cpp
1 // Copyright 2013 Yangqing Jia
3 #include <stdint.h>
4 #include <leveldb/db.h>
5 #include <pthread.h>
7 #include <string>
8 #include <vector>
10 #include "caffe/layer.hpp"
11 #include "caffe/util/io.hpp"
12 #include "caffe/vision_layers.hpp"
14 using std::string;
16 namespace caffe {
18 template <typename Dtype>
19 void* DataLayerPrefetch(void* layer_pointer) {
20   DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
21   Datum datum;
22   Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
23   Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
24   const Dtype scale = layer->layer_param_.scale();
25   const int batchsize = layer->layer_param_.batchsize();
26   const int cropsize = layer->layer_param_.cropsize();
27   const bool mirror = layer->layer_param_.mirror();
29   if (mirror && cropsize == 0) {
30     LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
31         << "set at the same time.";
32   }
33   // datum scales
34   const int channels = layer->datum_channels_;
35   const int height = layer->datum_height_;
36   const int width = layer->datum_width_;
37   const int size = layer->datum_size_;
38   const Dtype* mean = layer->data_mean_.cpu_data();
39   for (int itemid = 0; itemid < batchsize; ++itemid) {
40     // get a blob
41     datum.ParseFromString(layer->iter_->value().ToString());
42     const string& data = datum.data();
43     if (cropsize) {
44       CHECK(data.size()) << "Image cropping only support uint8 data";
45       int h_off = rand() % (height - cropsize);
46       int w_off = rand() % (width - cropsize);
47       if (mirror && rand() % 2) {
48         // Copy mirrored version
49         for (int c = 0; c < channels; ++c) {
50           for (int h = 0; h < cropsize; ++h) {
51             for (int w = 0; w < cropsize; ++w) {
52               top_data[((itemid * channels + c) * cropsize + h) * cropsize
53                        + cropsize - 1 - w] =
54                   (static_cast<Dtype>(
55                       (uint8_t)data[(c * height + h + h_off) * width
56                                     + w + w_off])
57                     - mean[(c * height + h + h_off) * width + w + w_off])
58                   * scale;
59             }
60           }
61         }
62       } else {
63         // Normal copy
64         for (int c = 0; c < channels; ++c) {
65           for (int h = 0; h < cropsize; ++h) {
66             for (int w = 0; w < cropsize; ++w) {
67               top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
68                   = (static_cast<Dtype>(
69                       (uint8_t)data[(c * height + h + h_off) * width
70                                     + w + w_off])
71                      - mean[(c * height + h + h_off) * width + w + w_off])
72                   * scale;
73             }
74           }
75         }
76       }
77     } else {
78       // we will prefer to use data() first, and then try float_data()
79       if (data.size()) {
80         for (int j = 0; j < size; ++j) {
81           top_data[itemid * size + j] =
82               (static_cast<Dtype>((uint8_t)data[j]) - mean[j]) * scale;
83         }
84       } else {
85         for (int j = 0; j < size; ++j) {
86           top_data[itemid * size + j] =
87               (datum.float_data(j) - mean[j]) * scale;
88         }
89       }
90     }
92     top_label[itemid] = datum.label();
93     // go to the next iter
94     layer->iter_->Next();
95     if (!layer->iter_->Valid()) {
96       // We have reached the end. Restart from the first.
97       LOG(INFO) << "Restarting data read from start.";
98       layer->iter_->SeekToFirst();
99     }
100   }
102   return (void*)NULL;
106 template <typename Dtype>
107 void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
108       vector<Blob<Dtype>*>* top) {
109   CHECK_EQ(bottom.size(), 0) << "Neuron Layer takes no input blobs.";
110   CHECK_EQ(top->size(), 2) << "Neuron Layer takes two blobs as output.";
111   // Initialize the leveldb
112   leveldb::DB* db_temp;
113   leveldb::Options options;
114   options.create_if_missing = false;
115   LOG(INFO) << "Opening leveldb " << this->layer_param_.source();
116   leveldb::Status status = leveldb::DB::Open(
117       options, this->layer_param_.source(), &db_temp);
118   CHECK(status.ok()) << "Failed to open leveldb "
119       << this->layer_param_.source();
120   db_.reset(db_temp);
121   iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
122   iter_->SeekToFirst();
123   // Read a data point, and use it to initialize the top blob.
124   Datum datum;
125   datum.ParseFromString(iter_->value().ToString());
126   // image
127   int cropsize = this->layer_param_.cropsize();
128   if (cropsize > 0) {
129     (*top)[0]->Reshape(
130         this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
131     prefetch_data_.reset(new Blob<Dtype>(
132         this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
133   } else {
134     (*top)[0]->Reshape(
135         this->layer_param_.batchsize(), datum.channels(), datum.height(),
136         datum.width());
137     prefetch_data_.reset(new Blob<Dtype>(
138         this->layer_param_.batchsize(), datum.channels(), datum.height(),
139         datum.width()));
140   }
141   LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
142       << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
143       << (*top)[0]->width();
144   // label
145   (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
146   prefetch_label_.reset(
147       new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
148   // datum size
149   datum_channels_ = datum.channels();
150   datum_height_ = datum.height();
151   datum_width_ = datum.width();
152   datum_size_ = datum.channels() * datum.height() * datum.width();
153   CHECK_GT(datum_height_, cropsize);
154   CHECK_GT(datum_width_, cropsize);
155   // check if we want to have mean
156   if (this->layer_param_.has_meanfile()) {
157     BlobProto blob_proto;
158     LOG(INFO) << "Loading mean file from" << this->layer_param_.meanfile();
159     ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
160     data_mean_.FromProto(blob_proto);
161     CHECK_EQ(data_mean_.num(), 1);
162     CHECK_EQ(data_mean_.channels(), datum_channels_);
163     CHECK_EQ(data_mean_.height(), datum_height_);
164     CHECK_EQ(data_mean_.width(), datum_width_);
165   } else {
166     // Simply initialize an all-empty mean.
167     data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
168   }
169   // Now, start the prefetch thread. Before calling prefetch, we make two
170   // cpu_data calls so that the prefetch thread does not accidentally make
171   // simultaneous cudaMalloc calls when the main thread is running. In some
172   // GPUs this seems to cause failures if we do not so.
173   prefetch_data_->mutable_cpu_data();
174   prefetch_label_->mutable_cpu_data();
175   data_mean_.cpu_data();
176   // LOG(INFO) << "Initializing prefetch";
177   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
178       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
179   // LOG(INFO) << "Prefetch initialized.";
182 template <typename Dtype>
183 void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
184       vector<Blob<Dtype>*>* top) {
185   // First, join the thread
186   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
187   // Copy the data
188   memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
189       sizeof(Dtype) * prefetch_data_->count());
190   memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
191       sizeof(Dtype) * prefetch_label_->count());
192   // Start a new prefetch thread
193   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
194       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
197 template <typename Dtype>
198 void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
199       vector<Blob<Dtype>*>* top) {
200   // First, join the thread
201   CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
202   // Copy the data
203   CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
204       prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
205       cudaMemcpyHostToDevice));
206   CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
207       prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
208       cudaMemcpyHostToDevice));
209   // Start a new prefetch thread
210   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
211       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
214 // The backward operations are dummy - they do not carry any computation.
215 template <typename Dtype>
216 Dtype DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
217       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
218   return Dtype(0.);
221 template <typename Dtype>
222 Dtype DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
223       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
224   return Dtype(0.);
227 INSTANTIATE_CLASS(DataLayer);
229 }  // namespace caffe