author | Yangqing Jia <jiayq84@gmail.com> | |
Wed, 30 Oct 2013 18:20:51 +0000 (11:20 -0700) | ||
committer | Yangqing Jia <jiayq84@gmail.com> | |
Wed, 30 Oct 2013 18:20:51 +0000 (11:20 -0700) |
index c0d410ad9b3d649731b78b12870f2ca386983fde..6b5ee78aee54a50d93dbaf19f7fa97c6f404038f 100644 (file)
int main(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
+ if (argc < 4) {
+ LOG(ERROR) << "Usage: convert_imageset ROOTFOLDER LISTFILE DB_NAME [0/1]";
+ return 0;
+ }
std::ifstream infile(argv[2]);
std::vector<std::pair<string, int> > lines;
string filename;
std::random_shuffle(lines.begin(), lines.end());
}
LOG(INFO) << "A total of " << lines.size() << "images.";
-
+
leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
index a57badfc005269f3dba1992b6b19ecc6a9429ef7..2e24beffacf89525c447cb70ba718f3e5f1cfc5e 100644 (file)
int WIDTH_;
int POOLED_HEIGHT_;
int POOLED_WIDTH_;
+ Blob<float> rand_idx_;
};
index 59ce3fe71e4486ad2a6656d20cccb397789151c7..6141642155d231796fcf89dc7ca0c7a845100fdf 100644 (file)
ceil(static_cast<float>(WIDTH_ - KSIZE_) / STRIDE_)) + 1;
(*top)[0]->Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
POOLED_WIDTH_);
+ // If stochastic pooling, we will initialize the random index part.
+ if (this->layer_param_.pool() == LayerParameter_PoolMethod_STOCHASTIC) {
+ rand_idx_.Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
+ POOLED_WIDTH_);
+ }
};
// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
}
}
break;
+ case LayerParameter_PoolMethod_STOCHASTIC:
+ NOT_IMPLEMENTED;
+ break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
}
}
break;
+ case LayerParameter_PoolMethod_STOCHASTIC:
+ NOT_IMPLEMENTED;
+ break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
index 9d15c534f179b2cfc9a97398d2b6eb46b6df635d..1cbb4abe8a270c71290ed79648dacfbb32bad7eb 100644 (file)
} // (if index < nthreads)
}
+// Stochastic-pooling forward pass, training phase. One thread per output
+// (pooled) element. On entry, rand_idx[index] holds a uniform sample in
+// [0, 1) (filled by curandGenerateUniform in the host-side launch code);
+// on exit it is overwritten with the flat offset into bottom_data of the
+// window element that was sampled, so the backward kernel can route the
+// gradient to exactly that element.
+// NOTE(review): the sampling assumes non-negative activations (e.g. post
+// ReLU). If the whole window sums to 0, thres == 0 and the first window
+// element is selected; if activations could be negative, the cumulative
+// sums are not a valid distribution — confirm against callers.
+template <typename Dtype>
+__global__ void StoPoolForwardTrain(const int nthreads,
+ const Dtype* bottom_data,
+ const int num, const int channels, const int height,
+ const int width, const int pooled_height, const int pooled_width,
+ const int ksize, const int stride, float* rand_idx, Dtype* top_data) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < nthreads) {
+ // Decompose the flat output index into (n, c, ph, pw).
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ // Window bounds in bottom coordinates, clipped to the input extent.
+ int hstart = ph * stride;
+ int hend = min(hstart + ksize, height);
+ int wstart = pw * stride;
+ int wend = min(wstart + ksize, width);
+ Dtype cumsum = 0.;
+ // Advance to this (n, c) plane of the input.
+ bottom_data += (n * channels + c) * height * width;
+ // First pass: get sum
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ cumsum += bottom_data[h * width + w];
+ }
+ }
+ // Scale the uniform sample to the window's total mass: element k is
+ // chosen with probability a_k / sum(a), per stochastic pooling.
+ float thres = rand_idx[index] * cumsum;
+ // Second pass: get value, and set index.
+ cumsum = 0;
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ cumsum += bottom_data[h * width + w];
+ if (cumsum >= thres) {
+ // Record the GLOBAL flat bottom offset (includes the (n, c)
+ // plane) — the backward kernel compares its own flat index
+ // against this value.
+ rand_idx[index] = ((n * channels + c) * height + h) * width + w;
+ top_data[index] = bottom_data[h * width + w];
+ return;
+ }
+ }
+ }
+ // NOTE(review): if floating-point drift makes the second-pass cumsum
+ // fall short of thres, nothing is written and top_data[index] /
+ // rand_idx[index] are left stale — consider a fallback to the last
+ // window element.
+ } // (if index < nthreads)
+}
+
+
+// Stochastic-pooling forward pass, test phase. One thread per output
+// element. Instead of sampling, emits the probability-weighted average of
+// the window: sum(a_k * a_k) / sum(a_k), i.e. the expected value of the
+// training-time sampling scheme.
+// NOTE(review): if the window sums to exactly 0, this computes 0/0 -> NaN;
+// consider adding a small epsilon (e.g. FLT_MIN) to the denominator.
+template <typename Dtype>
+__global__ void StoPoolForwardTest(const int nthreads,
+ const Dtype* bottom_data,
+ const int num, const int channels, const int height,
+ const int width, const int pooled_height, const int pooled_width,
+ const int ksize, const int stride, Dtype* top_data) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < nthreads) {
+ // Decompose the flat output index into (n, c, ph, pw).
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ // Window bounds in bottom coordinates, clipped to the input extent.
+ int hstart = ph * stride;
+ int hend = min(hstart + ksize, height);
+ int wstart = pw * stride;
+ int wend = min(wstart + ksize, width);
+ Dtype cumsum = 0.;
+ Dtype cumvalues = 0.;
+ // Advance to this (n, c) plane of the input.
+ bottom_data += (n * channels + c) * height * width;
+ // First pass: get sum
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ cumsum += bottom_data[h * width + w];
+ cumvalues += bottom_data[h * width + w] * bottom_data[h * width + w];
+ }
+ }
+ // Expected pooled value: sum of a^2 over sum of a.
+ top_data[index] = cumvalues / cumsum;
+ } // (if index < nthreads)
+}
+
+
template <typename Dtype>
void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
top_data);
break;
+ case LayerParameter_PoolMethod_STOCHASTIC:
+ if (Caffe::phase() == Caffe::TRAIN) {
+ // We need to create the random index as well.
+ CURAND_CHECK(curandGenerateUniform(Caffe::curand_generator(),
+ rand_idx_.mutable_gpu_data(), count));
+ StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, bottom[0]->num(), CHANNELS_,
+ HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+ rand_idx_.mutable_gpu_data(), top_data);
+ } else {
+ StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, bottom[0]->num(), CHANNELS_,
+ HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
+ top_data);
+ }
+ break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
} // (if index < nthreads)
}
+
+// Stochastic-pooling backward pass. One thread per BOTTOM element. Each
+// thread scans every pooling window that overlaps its bottom position and
+// accumulates top_diff from the windows whose recorded rand_idx (the
+// global flat bottom offset written by StoPoolForwardTrain) equals this
+// thread's own flat index — i.e. the gradient flows only to the element
+// that was sampled in the forward pass.
+// NOTE(review): indices are round-tripped through float (rand_idx); floats
+// cannot represent all integers above 2^24, so very large blobs
+// (num*channels*height*width > ~16.7M) could misroute gradients — verify
+// expected blob sizes.
+template <typename Dtype>
+__global__ void StoPoolBackward(const int nthreads,
+ const float* rand_idx, const Dtype* top_diff,
+ const int num, const int channels, const int height,
+ const int width, const int pooled_height, const int pooled_width,
+ const int ksize, const int stride, Dtype* bottom_diff) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < nthreads) {
+ // find out the local index
+ // find out the local offset
+ int w = index % width;
+ int h = (index / width) % height;
+ int c = (index / width / height) % channels;
+ int n = index / width / height / channels;
+ // Range of pooled-output positions whose windows cover (h, w).
+ int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+ int phend = min(h / stride + 1, pooled_height);
+ int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+ int pwend = min(w / stride + 1, pooled_width);
+ Dtype gradient = 0;
+ // Advance both arrays to this (n, c) plane of the pooled output.
+ rand_idx += (n * channels + c) * pooled_height * pooled_width;
+ top_diff += (n * channels + c) * pooled_height * pooled_width;
+ for (int ph = phstart; ph < phend; ++ph) {
+ for (int pw = pwstart; pw < pwend; ++pw) {
+ // Multiply by 0/1 mask instead of branching: contributes only
+ // when this bottom element was the one sampled for that window.
+ gradient += top_diff[ph * pooled_width + pw] *
+ (index == int(rand_idx[ph * pooled_width + pw]));
+ }
+ }
+ bottom_diff[index] = gradient;
+ } // (if index < nthreads)
+}
+
+
template <typename Dtype>
Dtype PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
HEIGHT_, WIDTH_, POOLED_HEIGHT_, POOLED_WIDTH_, KSIZE_, STRIDE_,
bottom_diff);
break;
+ case LayerParameter_PoolMethod_STOCHASTIC:
+ StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, rand_idx_.gpu_data(), top_diff,
+ top[0]->num(), CHANNELS_, HEIGHT_, WIDTH_, POOLED_HEIGHT_,
+ POOLED_WIDTH_, KSIZE_, STRIDE_, bottom_diff);
+ break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
index 22716a1873153fae7334f310804e67aee16ed255..16bb352e964c0accfb9b385f65c2efcaa02c9b71 100644 (file)
enum PoolMethod {
MAX = 0;
AVE = 1;
+ STOCHASTIC = 2;
}
optional PoolMethod pool = 11 [default = MAX]; // The pooling method
optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio