summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: 0484890)
raw | patch | inline | side by side (parent: 0484890)
author | Yangqing Jia <jiayq84@gmail.com> | |
Fri, 27 Sep 2013 17:25:23 +0000 (10:25 -0700) | ||
committer | Yangqing Jia <jiayq84@gmail.com> | |
Fri, 27 Sep 2013 17:25:23 +0000 (10:25 -0700) |
src/caffe/filler.hpp | patch | blob | history | |
src/caffe/test/lenet.hpp | patch | blob | history | |
src/caffe/test/test_net_proto.cpp | patch | blob | history |
diff --git a/src/caffe/filler.hpp b/src/caffe/filler.hpp
index f4554600ae23f92b14b54db4129f17da31df8af1..ffe7a5065593946cb4bb53d5b254cbbdea83f861 100644 (file)
--- a/src/caffe/filler.hpp
+++ b/src/caffe/filler.hpp
explicit UniformFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
- DCHECK(blob->count());
+ CHECK(blob->count());
caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
Dtype(this->filler_param_.min()),
Dtype(this->filler_param_.max()));
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
Dtype* data = blob->mutable_cpu_data();
- DCHECK(blob->count());
+ CHECK(blob->count());
caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
Dtype(this->filler_param_.mean()),
Dtype(this->filler_param_.std()));
// We expect the filler to not be called very frequently, so we will
// just use a simple implementation
int dim = blob->count() / blob->num();
- DCHECK(dim);
+ CHECK(dim);
for (int i = 0; i < blob->num(); ++i) {
Dtype sum = 0;
for (int j = 0; j < dim; ++j) {
}
};
+// A filler based on the paper [Bengio and Glorot 2010]: Understanding
+// the difficulty of training deep feedforward neuralnetworks, but does not
+// use the fan_out value.
+//
+// It fills the incoming matrix by randomly sampling uniform data from
+// [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
+// of input nodes, and in our case we consider the blob width as the scale.
+// You should make sure the input blob has shape (1, 1, height, width).
template <typename Dtype>
class XavierFiller : public Filler<Dtype> {
public:
explicit XavierFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
-
+ CHECK(blob->count());
+ CHECK_EQ(blob->num(), 1) << "XavierFiller requires blob.num() = 1.";
+ CHECK_EQ(blob->channels(), 1)
+ << "XavierFiller requires blob.channels() = 1.";
+ int fan_in = blob->width();
+ Dtype scale = sqrt(Dtype(3) / fan_in);
+ caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+ -scale, scale);
}
};
const std::string& type = param.type();
if (type == "constant") {
return new ConstantFiller<Dtype>(param);
- } else if (type == "uniform") {
- return new UniformFiller<Dtype>(param);
} else if (type == "gaussian") {
return new GaussianFiller<Dtype>(param);
} else if (type == "positive_unitball") {
return new PositiveUnitballFiller<Dtype>(param);
+ } else if (type == "uniform") {
+ return new UniformFiller<Dtype>(param);
+ } else if (type == "xavier") {
+ return new XavierFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}
index 29ec3c47fac5be79f27f0e4a45856d0863b845bf..266f0b2ff2e527935affa425068ac90d91bfd804 100644 (file)
--- a/src/caffe/test/lenet.hpp
+++ b/src/caffe/test/lenet.hpp
num_output: 20\n\
kernelsize: 5\n\
stride: 1\n\
+ weight_filler {\n\
+ type: \"xavier\"\n\
+ }\n\
+ bias_filler {\n\
+ type: \"constant\"\n\
+ }\n\
}\n\
bottom: \"data\"\n\
top: \"conv1\"\n\
num_output: 50\n\
kernelsize: 5\n\
stride: 1\n\
+ weight_filler {\n\
+ type: \"xavier\"\n\
+ }\n\
+ bias_filler {\n\
+ type: \"constant\"\n\
+ }\n\
}\n\
bottom: \"pool1\"\n\
top: \"conv2\"\n\
name: \"ip1\"\n\
type: \"innerproduct\"\n\
num_output: 500\n\
+ weight_filler {\n\
+ type: \"xavier\"\n\
+ }\n\
+ bias_filler {\n\
+ type: \"constant\"\n\
+ }\n\
}\n\
bottom: \"pool2\"\n\
top: \"ip1\"\n\
name: \"ip2\"\n\
type: \"innerproduct\"\n\
num_output: 10\n\
+ weight_filler {\n\
+ type: \"xavier\"\n\
+ }\n\
+ bias_filler {\n\
+ type: \"constant\"\n\
+ }\n\
}\n\
bottom: \"relu1\"\n\
top: \"ip2\"\n\
index f0b0e7d9c0d002a1a2f30e005d4b27f96c72ce4f..013bd67ae51bb81edf811e78d30ec86514d4b2df 100644 (file)
shared_ptr<Filler<TypeParam> > filler;
filler.reset(new ConstantFiller<TypeParam>(filler_param));
filler->Fill(label.get());
- filler.reset(new GaussianFiller<TypeParam>(filler_param));
+ filler.reset(new UniformFiller<TypeParam>(filler_param));
filler->Fill(data.get());
vector<Blob<TypeParam>*> bottom_vec;