index d42a810410651ea06b826f574c713e55ae7c4458..12fd6d94625e68e82188951e208c8af728ae85c3 100644
#include <stdint.h>
#include <leveldb/db.h>
+#include <pthread.h>
#include <string>
#include <vector>
#include "caffe/layer.hpp"
+#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
using std::string;
namespace caffe {
+template <typename Dtype>
+void* DataLayerPrefetch(void* layer_pointer) {
+  DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
+  Datum datum;
+  Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
+  Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
+  const Dtype scale = layer->layer_param_.scale();
+  const int batchsize = layer->layer_param_.batchsize();
+  const int cropsize = layer->layer_param_.cropsize();
+  const bool mirror = layer->layer_param_.mirror();
+
+  if (mirror && cropsize == 0) {
+    LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
+        << "set at the same time.";
+  }
+  // datum scales
+  const int channels = layer->datum_channels_;
+  const int height = layer->datum_height_;
+  const int width = layer->datum_width_;
+  const int size = layer->datum_size_;
+  const Dtype* mean = layer->data_mean_.cpu_data();
+  for (int itemid = 0; itemid < batchsize; ++itemid) {
+    // get a blob
+    datum.ParseFromString(layer->iter_->value().ToString());
+    const string& data = datum.data();
+    if (cropsize) {
+      CHECK(data.size()) << "Image cropping only supports uint8 data";
+      int h_off, w_off;
+      // We only do random crop when we do training.
+      if (Caffe::phase() == Caffe::TRAIN) {
+        h_off = rand() % (height - cropsize);
+        w_off = rand() % (width - cropsize);
+      } else {
+        h_off = (height - cropsize) / 2;
+        w_off = (width - cropsize) / 2;
+      }
+      if (mirror && rand() % 2) {
+        // Copy mirrored version
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize
+                       + cropsize - 1 - w] =
+                  (static_cast<Dtype>(
+                      (uint8_t)data[(c * height + h + h_off) * width
+                                    + w + w_off])
+                   - mean[(c * height + h + h_off) * width + w + w_off])
+                  * scale;
+            }
+          }
+        }
+      } else {
+        // Normal copy
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
+                  = (static_cast<Dtype>(
+                      (uint8_t)data[(c * height + h + h_off) * width
+                                    + w + w_off])
+                     - mean[(c * height + h + h_off) * width + w + w_off])
+                  * scale;
+            }
+          }
+        }
+      }
+    } else {
+      // we will prefer to use data() first, and then try float_data()
+      if (data.size()) {
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
+              (static_cast<Dtype>((uint8_t)data[j]) - mean[j]) * scale;
+        }
+      } else {
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
+              (datum.float_data(j) - mean[j]) * scale;
+        }
+      }
+    }
+
+    top_label[itemid] = datum.label();
+    // go to the next iter
+    layer->iter_->Next();
+    if (!layer->iter_->Valid()) {
+      // We have reached the end. Restart from the first.
+      DLOG(INFO) << "Restarting data prefetching from start.";
+      layer->iter_->SeekToFirst();
+    }
+  }
+
+  return (void*)NULL;
+}
+
+
template <typename Dtype>
void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Datum datum;
datum.ParseFromString(iter_->value().ToString());
// image
-  (*top)[0]->Reshape(
-      this->layer_param_.batchsize(), datum.channels(), datum.height(),
-      datum.width());
+  int cropsize = this->layer_param_.cropsize();
+  if (cropsize > 0) {
+    (*top)[0]->Reshape(
+        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
+    prefetch_data_.reset(new Blob<Dtype>(
+        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
+  } else {
+    (*top)[0]->Reshape(
+        this->layer_param_.batchsize(), datum.channels(), datum.height(),
+        datum.width());
+    prefetch_data_.reset(new Blob<Dtype>(
+        this->layer_param_.batchsize(), datum.channels(), datum.height(),
+        datum.width()));
+  }
+  LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
+      << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
+      << (*top)[0]->width();
// label
(*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
+  prefetch_label_.reset(
+      new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
// datum size
-  datum_size_ = datum.channels() * datum.height() * datum.width();
+  datum_channels_ = datum.channels();
+  datum_height_ = datum.height();
+  datum_width_ = datum.width();
+  datum_size_ = datum.channels() * datum.height() * datum.width();
+  CHECK_GT(datum_height_, cropsize);
+  CHECK_GT(datum_width_, cropsize);
+  // check if we want to have mean
+  if (this->layer_param_.has_meanfile()) {
+    BlobProto blob_proto;
+    LOG(INFO) << "Loading mean file from " << this->layer_param_.meanfile();
+    ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
+    data_mean_.FromProto(blob_proto);
+    CHECK_EQ(data_mean_.num(), 1);
+    CHECK_EQ(data_mean_.channels(), datum_channels_);
+    CHECK_EQ(data_mean_.height(), datum_height_);
+    CHECK_EQ(data_mean_.width(), datum_width_);
+  } else {
+    // Simply initialize an all-zero mean.
+    data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
+  }
+  // Now, start the prefetch thread. Before calling prefetch, we make the
+  // cpu_data calls so that the prefetch thread does not accidentally make
+  // simultaneous cudaMalloc calls when the main thread is running. On some
+  // GPUs this seems to cause failures if we do not do so.
+  prefetch_data_->mutable_cpu_data();
+  prefetch_label_->mutable_cpu_data();
+  data_mean_.cpu_data();
+  DLOG(INFO) << "Initializing prefetch";
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
+  DLOG(INFO) << "Prefetch initialized.";
}
template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
-  Datum datum;
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* top_label = (*top)[1]->mutable_cpu_data();
-  const Dtype scale = this->layer_param_.scale();
-  const Dtype subtraction = this->layer_param_.subtraction();
-  // LOG(ERROR) << "Debug code on";
-  // if (true) {
-  //   iter_->SeekToFirst();
-  // }
-  for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
-    // get a blob
-    datum.ParseFromString(iter_->value().ToString());
-    const string& data = datum.data();
-    // we will prefer to use data() first, and then try float_data()
-    if (data.size()) {
-      for (int j = 0; j < datum_size_; ++j) {
-        top_data[i * datum_size_ + j] =
-            (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
-      }
-    } else {
-      for (int j = 0; j < datum_size_; ++j) {
-        top_data[i * datum_size_ + j] =
-            (datum.float_data(j) * scale) - subtraction;
-      }
-    }
-    top_label[i] = datum.label();
-    // go to the next iter
-    iter_->Next();
-    if (!iter_->Valid()) {
-      // We have reached the end. Restart from the first.
-      LOG(INFO) << "Restarting data read from start.";
-      iter_->SeekToFirst();
-    }
-  }
+  // First, join the thread
+  CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
+  // Copy the data
+  memcpy((*top)[0]->mutable_cpu_data(), prefetch_data_->cpu_data(),
+      sizeof(Dtype) * prefetch_data_->count());
+  memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
+      sizeof(Dtype) * prefetch_label_->count());
+  // Start a new prefetch thread
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}
template <typename Dtype>
void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
-  Forward_cpu(bottom, top);
-  // explicitly copy data to gpu - this is achieved by simply calling gpu_data
-  // functions.
-  // TODO(Yangqing): maybe we don't need this since data synchronization is
-  // simply done under the hood?
-  (*top)[0]->gpu_data();
-  (*top)[1]->gpu_data();
+  // First, join the thread
+  CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
+  // Copy the data
+  CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
+      prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
+      cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
+      prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
+      cudaMemcpyHostToDevice));
+  // Start a new prefetch thread
+  CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+      reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}
// The backward operations are dummy - they do not carry any computation.