author     Yangqing Jia <jiayq84@gmail.com>
           Thu, 26 Sep 2013 23:08:40 +0000 (16:08 -0700)
committer  Yangqing Jia <jiayq84@gmail.com>
           Thu, 26 Sep 2013 23:08:40 +0000 (16:08 -0700)
parent     c4db406
16 files changed:
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/common.hpp
src/caffe/layer.hpp
src/caffe/layer_factory.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/proto/lenet.prototxt               [deleted file]
src/caffe/test/lenet.hpp                     [new file with mode: 0644]
src/caffe/test/test_blob.cpp
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_innerproduct_layer.cpp
src/caffe/test/test_net_proto.cpp            [new file with mode: 0644]
src/caffe/test/test_pooling_layer.cpp
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index ecb37b7f15ef7730d1d97ee52a5d778f1b967c3a..0a00ce5a539402cde34fc020e00fedc9a43afd25 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
- CHECK_GT(num, 0);
- CHECK_GT(channels, 0);
- CHECK_GT(height, 0);
- CHECK_GT(width, 0);
+ CHECK_GE(num, 0);
+ CHECK_GE(channels, 0);
+ CHECK_GE(height, 0);
+ CHECK_GE(width, 0);
num_ = num;
channels_ = channels;
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
- data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ if (count_) {
+ data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ } else {
+ data_.reset((SyncedMemory*)NULL);
+ diff_.reset((SyncedMemory*)NULL);
+ }
}
template <typename Dtype>
Reshape(num, channels, height, width);
}
-template <typename Dtype>
-Blob<Dtype>::Blob(const Blob<Dtype>& source) {
- if (source.count() == 0) {
- Blob();
- } else {
- Reshape(source.num(), source.channels(), source.height(),
- source.width());
- if (count_ > 0) {
- // Copy the data.
- memcpy(data_->mutable_cpu_data(), source.cpu_data(),
- count_ * sizeof(Dtype));
- memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
- count_ * sizeof(Dtype));
- }
- }
-}
-
-template <typename Dtype>
-const Blob<Dtype>& Blob<Dtype>::operator=(const Blob<Dtype>& source) {
- Reshape(source.num(), source.channels(), source.height(),
- source.width());
- if (count_ > 0) {
- // Copy the data.
- memcpy(data_->mutable_cpu_data(), source.cpu_data(),
- count_ * sizeof(Dtype));
- memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
- count_ * sizeof(Dtype));
- }
- return (*this);
-}
-
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
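Note on the Reshape change above: the dimension checks drop from CHECK_GT to CHECK_GE and allocation is skipped when the element count is zero, so an empty Blob is now a legal state rather than a fatal error. A minimal, purely illustrative sketch of the new behavior (assumes only the Blob interface shown in this commit; the function name is made up):

    #include "caffe/blob.hpp"

    void empty_blob_sketch() {
      caffe::Blob<float> scratch;      // default-constructed: count() == 0, owns no memory
      scratch.Reshape(0, 3, 4, 5);     // passes CHECK_GE; count_ stays 0 and data_/diff_ are reset to null
      scratch.Reshape(2, 3, 4, 5);     // a later reshape allocates 2*3*4*5 elements as before
    }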
diff --git a/src/caffe/blob.hpp b/src/caffe/blob.hpp
index a14c046130799b80e554bd63555d49d26291bd10..39a2cf05cfa6033fa420d5ed855c3f95fdf494be 100644
--- a/src/caffe/blob.hpp
+++ b/src/caffe/blob.hpp
diff_() {}
explicit Blob(const int num, const int channels, const int height,
const int width);
- Blob(const Blob<Dtype>& source);
- const Blob<Dtype>& operator=(const Blob<Dtype>& src);
virtual ~Blob() {}
void Reshape(const int num, const int height,
const int width, const int channels);
int height_;
int width_;
int count_;
+
+ DISABLE_COPY_AND_ASSIGN(Blob);
}; // class Blob
} // namespace caffe
diff --git a/src/caffe/common.hpp b/src/caffe/common.hpp
index 67177a664b5ce88d766015148df30285a41858ca..18c5b413b7b3180a2136df835c43c399c25caf78 100644
--- a/src/caffe/common.hpp
+++ b/src/caffe/common.hpp
LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
}
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+ private:\
+ classname(const classname&);\
+ classname& operator=(const classname&)
+
#define INSTANTIATE_CLASS(classname) \
template class classname<float>; \
template class classname<double>
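DISABLE_COPY_AND_ASSIGN is the usual pre-C++11 noncopyable idiom: it declares a private, never-defined copy constructor and assignment operator, so accidental copies are rejected at compile time. Blob and Layer adopt it at the end of their class bodies in this commit; a hedged sketch with a made-up class shows the effect:

    #include "caffe/common.hpp"

    class Widget {                      // hypothetical class, not part of Caffe
     public:
      Widget() {}

      DISABLE_COPY_AND_ASSIGN(Widget);  // expands to private copy ctor and operator= declarations
    };

    // Widget a, b;
    // Widget c(a);                     // error: copy constructor is private
    // b = a;                           // error: operator= is private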
diff --git a/src/caffe/layer.hpp b/src/caffe/layer.hpp
index 551ed3b3c8259fe27e41082aa48cb47bb7c224cf..b82f03806f88935252993a2046d061b6ca73036f 100644
--- a/src/caffe/layer.hpp
+++ b/src/caffe/layer.hpp
vector<Blob<Dtype>*>* bottom);
// Returns the vector of parameters.
- vector<Blob<Dtype> >& params() {
+ vector<shared_ptr<Blob<Dtype> > >& params() {
return blobs_;
}
// The protobuf that stores the layer parameters
LayerParameter layer_param_;
// The vector that stores the parameters as a set of blobs.
- vector<Blob<Dtype> > blobs_;
+ vector<shared_ptr<Blob<Dtype> > > blobs_;
// Forward functions
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
LOG(WARNING) << "Using CPU code as backup.";
return Backward_cpu(top, propagate_down, bottom);
};
+
+ DISABLE_COPY_AND_ASSIGN(Layer);
}; // class Layer
// Forward and backward wrappers. You should implement the cpu and
param->CopyFrom(layer_param_);
param->clear_blobs();
for (int i = 0; i < blobs_.size(); ++i) {
- blobs_[i].ToProto(param->add_blobs(), write_diff);
+ blobs_[i]->ToProto(param->add_blobs(), write_diff);
}
}
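The switch from vector<Blob<Dtype> > to vector<shared_ptr<Blob<Dtype> > > is the counterpart of making Blob noncopyable: each parameter blob is allocated exactly once with reset(new Blob(...)), and fillers and the net hold pointers to it instead of copies. A small standalone sketch of the pattern (names are illustrative; it assumes shared_ptr is pulled in through caffe/common.hpp as the headers above do):

    #include <vector>
    #include "caffe/blob.hpp"

    namespace caffe {

    void param_blob_sketch() {
      std::vector<shared_ptr<Blob<float> > > blobs(1);
      blobs[0].reset(new Blob<float>(1, 1, 10, 20));         // allocate the weight blob once
      Blob<float>* weight = blobs[0].get();                  // raw pointer handed to fillers / gemm calls
      std::vector<shared_ptr<Blob<float> > > alias = blobs;  // copying the vector copies pointers, never blob data
      (void)weight; (void)alias;
    }

    }  // namespace caffe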
diff --git a/src/caffe/layer_factory.hpp b/src/caffe/layer_factory.hpp
index 90e6d66d9557343c8950e8a3c82cfab9cfdb4252..8453fd51980cff30e99ff2e6ea813c2c686d853f 100644
--- a/src/caffe/layer_factory.hpp
+++ b/src/caffe/layer_factory.hpp
return new PoolingLayer<Dtype>(param);
} else if (type == "relu") {
return new ReLULayer<Dtype>(param);
+ } else if (type == "softmax") {
+ return new SoftmaxLayer<Dtype>(param);
+ } else if (type == "multinomial_logistic_loss") {
+ return new MultinomialLogisticLossLayer<Dtype>(param);
} else {
LOG(FATAL) << "Unknown filler name: " << type;
}
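GetLayer is plain string dispatch on LayerParameter.type, and this hunk registers the two new types the LeNet test needs: softmax and multinomial_logistic_loss. A hedged sketch of the calling side, mirroring how net.cpp uses the factory below (the field setters are the standard protobuf-generated ones; the function name is illustrative):

    #include "caffe/layer.hpp"
    #include "caffe/layer_factory.hpp"
    #include "caffe/proto/caffe.pb.h"

    namespace caffe {

    void factory_sketch() {
      LayerParameter param;
      param.set_name("prob");
      param.set_type("softmax");                                 // now resolves to SoftmaxLayer
      shared_ptr<Layer<float> > layer(GetLayer<float>(param));   // factory returns a freshly new'ed layer
    }

    }  // namespace caffe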
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index 849e10610f77ddf54cc5d261174a516ad361a762..9560e47d36b47eb949e9c329d793b3fd8c3689af 100644
--- a/src/caffe/layers/conv_layer.cpp
+++ b/src/caffe/layers/conv_layer.cpp
template <typename Dtype>
void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom.size(), 1) << "Im2col Layer takes a single blob as input.";
- CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output.";
+ CHECK_EQ(bottom.size(), 1) << "Conv Layer takes a single blob as input.";
+ CHECK_EQ(top->size(), 1) << "Conv Layer takes a single blob as output.";
KSIZE_ = this->layer_param_.kernelsize();
STRIDE_ = this->layer_param_.stride();
GROUP_ = this->layer_param_.group();
HEIGHT_ = bottom[0]->height();
WIDTH_ = bottom[0]->width();
NUM_OUTPUT_ = this->layer_param_.num_output();
+ CHECK_GT(NUM_OUTPUT_, 0);
CHECK_EQ(CHANNELS_ % GROUP_, 0);
// The im2col result buffer would only hold one image at a time to avoid
// overly large memory usage.
this->blobs_.resize(1);
}
// Intialize the weight
- this->blobs_[0].Reshape(1, 1, NUM_OUTPUT_, K_);
+ this->blobs_[0].reset(new Blob<Dtype>(1, 1, NUM_OUTPUT_, K_));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(
GetFiller<Dtype>(this->layer_param_.weight_filler()));
- weight_filler->Fill(&this->blobs_[0]);
+ weight_filler->Fill(this->blobs_[0].get());
// If necessary, intiialize and fill the bias term
if (biasterm_) {
- this->blobs_[1].Reshape(1, 1, 1, NUM_OUTPUT_);
+ this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, NUM_OUTPUT_));
shared_ptr<Filler<Dtype> > bias_filler(
GetFiller<Dtype>(this->layer_param_.bias_filler()));
- bias_filler->Fill(&this->blobs_[1]);
+ bias_filler->Fill(this->blobs_[1].get());
bias_multiplier_.reset(new SyncedMemory(N_ * sizeof(Dtype)));
Dtype* bias_multiplier_data =
reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
Dtype* col_data = col_buffer_.mutable_cpu_data();
- const Dtype* weight = this->blobs_[0].cpu_data();
+ const Dtype* weight = this->blobs_[0]->cpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
// third, add bias
if (biasterm_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
- N_, 1, (Dtype)1., this->blobs_[1].cpu_data(),
+ N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(),
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
(Dtype)1., top_data + (*top)[0]->offset(n));
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
Dtype* col_data = col_buffer_.mutable_gpu_data();
- const Dtype* weight = this->blobs_[0].gpu_data();
+ const Dtype* weight = this->blobs_[0]->gpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
// third, add bias
if (biasterm_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
- N_, 1, (Dtype)1., this->blobs_[1].gpu_data(),
+ N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
(Dtype)1., top_data + (*top)[0]->offset(n));
}
Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->cpu_diff();
- const Dtype* weight = this->blobs_[0].cpu_data();
- Dtype* weight_diff = this->blobs_[0].mutable_cpu_diff();
+ const Dtype* weight = this->blobs_[0]->cpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
Dtype* col_data = col_buffer_.mutable_cpu_data();
Dtype* bias_diff = NULL;
if (biasterm_) {
- bias_diff = this->blobs_[1].mutable_cpu_diff();
- memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1].count());
+ bias_diff = this->blobs_[1]->mutable_cpu_diff();
+ memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1]->count());
for (int n = 0; n < NUM_; ++n) {
caffe_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
1., top_diff + top[0]->offset(n),
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
- memset(weight_diff, 0., sizeof(Dtype) * this->blobs_[0].count());
+ memset(weight_diff, 0., sizeof(Dtype) * this->blobs_[0]->count());
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
- const Dtype* weight = this->blobs_[0].gpu_data();
- Dtype* weight_diff = this->blobs_[0].mutable_gpu_diff();
+ const Dtype* weight = this->blobs_[0]->gpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
Dtype* col_data = col_buffer_.mutable_gpu_data();
Dtype* bias_diff = NULL;
if (biasterm_) {
- bias_diff = this->blobs_[1].mutable_gpu_diff();
+ bias_diff = this->blobs_[1]->mutable_gpu_diff();
CUDA_CHECK(cudaMemset(bias_diff, 0.,
- sizeof(Dtype) * this->blobs_[1].count()));
+ sizeof(Dtype) * this->blobs_[1]->count()));
for (int n = 0; n < NUM_; ++n) {
caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
1., top_diff + top[0]->offset(n),
int col_offset = K_ * N_;
int top_offset = M_ * N_;
CUDA_CHECK(cudaMemset(weight_diff, 0.,
- sizeof(Dtype) * this->blobs_[0].count()));
+ sizeof(Dtype) * this->blobs_[0]->count()));
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index b39b5689c918248a5496a311a11c3987fcf08496..55b51c6fdc3af4277599e99ca15e09e5ecba85d8 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
this->blobs_.resize(1);
}
// Intialize the weight
- this->blobs_[0].Reshape(1, 1, N_, K_);
+ this->blobs_[0].reset(new Blob<Dtype>(1, 1, N_, K_));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(
GetFiller<Dtype>(this->layer_param_.weight_filler()));
- weight_filler->Fill(&this->blobs_[0]);
+ weight_filler->Fill(this->blobs_[0].get());
// If necessary, intiialize and fill the bias term
if (biasterm_) {
- this->blobs_[1].Reshape(1, 1, 1, N_);
+ this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, N_));
shared_ptr<Filler<Dtype> > bias_filler(
GetFiller<Dtype>(this->layer_param_.bias_filler()));
- bias_filler->Fill(&this->blobs_[1]);
+ bias_filler->Fill(this->blobs_[1].get());
bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
Dtype* bias_multiplier_data = (Dtype*)bias_multiplier_->mutable_cpu_data();
for (int i = 0; i < M_; ++i) {
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
- const Dtype* weight = this->blobs_[0].cpu_data();
+ const Dtype* weight = this->blobs_[0]->cpu_data();
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (biasterm_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
- this->blobs_[1].cpu_data(), (Dtype)1., top_data);
+ this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
}
}
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
// Gradient with respect to weight
caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
- bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
+ bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
if (biasterm_) {
// Gradient with respect to bias
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), (Dtype)0.,
- this->blobs_[1].mutable_cpu_diff());
+ this->blobs_[1]->mutable_cpu_diff());
}
if (propagate_down) {
// Gradient with respect to bottom data
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
+ top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
(*bottom)[0]->mutable_cpu_diff());
}
return Dtype(0);
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
- const Dtype* weight = this->blobs_[0].gpu_data();
+ const Dtype* weight = this->blobs_[0]->gpu_data();
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (biasterm_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
- this->blobs_[1].gpu_data(), (Dtype)1., top_data);
+ this->blobs_[1]->gpu_data(), (Dtype)1., top_data);
}
}
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
// Gradient with respect to weight
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
- bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
+ bottom_data, top_diff, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
if (biasterm_) {
// Gradient with respect to bias
caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
- (Dtype)0., this->blobs_[1].mutable_gpu_diff());
+ (Dtype)0., this->blobs_[1]->mutable_gpu_diff());
}
if (propagate_down) {
// Gradient with respect to bottom data
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
+ top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
(*bottom)[0]->mutable_gpu_diff());
}
return Dtype(0);
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 93969b159683f8807c5d30411acf88a7e170e3f9..75c9043977a9a064375c2b03e996fc1cf5cc8e74 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
// set the input blobs
for (int i = 0; i < param.bottom_size(); ++i) {
const string& blob_name = param.bottom(i);
- blobs_.push_back(Blob<Dtype>(*bottom[i]));
+ CHECK_GT(bottom[i]->count(), 0);
+ shared_ptr<Blob<Dtype> > blob_pointer(
+ new Blob<Dtype>(bottom[i]->num(), bottom[i]->channels(),
+ bottom[i]->height(), bottom[i]->width()));
+ blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
net_input_blob_indices_.push_back(i);
blob_name_to_idx[blob_name] = i;
available_blobs.insert(blob_name);
}
// For each layer, set up their input and output
- layers_.resize(param.layers_size());
bottom_vecs_.resize(param.layers_size());
top_vecs_.resize(param.layers_size());
- for (int i = 0; i < param.top_size(); ++i) {
+ for (int i = 0; i < param.layers_size(); ++i) {
const LayerConnection& layer_connection = param.layers(i);
const LayerParameter& layer_param = layer_connection.layer();
- layers_[i].reset(GetLayer<Dtype>(layer_param));
+ layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
+ layer_names_.push_back(layer_param.name());
+ LOG(INFO) << "Creating Layer " << layer_param.name();
// Figure out this layer's input and output
for (int j = 0; j < layer_connection.bottom_size(); ++j) {
const string& blob_name = layer_connection.bottom(j);
LOG(FATAL) << "Unknown blob input " << blob_name <<
" to layer" << j;
}
+ LOG(INFO) << layer_param.name() << " <- " << blob_name;
bottom_vecs_[i].push_back(
- &blobs_[blob_name_to_idx[blob_name]]);
+ blobs_[blob_name_to_idx[blob_name]].get());
available_blobs.erase(blob_name);
}
for (int j = 0; j < layer_connection.top_size(); ++j) {
if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) {
LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
}
- blobs_.push_back(Blob<Dtype>());
+ LOG(INFO) << layer_param.name() << " -> " << blob_name;
+ shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
+ blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
blob_name_to_idx[blob_name] = blob_names_.size() - 1;
available_blobs.insert(blob_name);
- top_vecs_[i].push_back(&blobs_[blob_names_.size() - 1]);
+ top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
}
}
+ LOG(INFO) << "Checking top blobs.";
// In the end, check if all remaining available blobs are top blobs.
for (int i = 0; i < param.top_size(); ++i) {
const string& blob_name = param.top(i);
if (blob_name_to_idx.find(blob_name) == blob_name_to_idx.end()) {
- LOG(FATAL) << "Unknown blob input " << blob_name;
+ LOG(FATAL) << "Unknown blob output " << blob_name;
}
net_output_blob_indices_.push_back(blob_name_to_idx[blob_name]);
available_blobs.erase(blob_name);
LOG(INFO) << "Setting up the layers.";
for (int i = 0; i < layers_.size(); ++i) {
+ LOG(INFO) << "Setting up " << layer_names_[i];
layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
}
+ LOG(INFO) << "Network initialization done.";
}
template <typename Dtype>
vector<Blob<Dtype>*>* top) {
// Copy bottom to internal bottom
for (int i = 0; i < bottom.size(); ++i) {
- blobs_[net_input_blob_indices_[i]] = *bottom[i];
+ memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
+ bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
}
for (int i = 0; i < layers_.size(); ++i) {
layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
continue;
}
LOG(INFO) << "Loading source layer " << source_layer_name;
- vector<Blob<Dtype> >& target_blobs = layers_[target_layer_id]->params();
+ vector<shared_ptr<Blob<Dtype> > >& target_blobs =
+ layers_[target_layer_id]->params();
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
- target_blobs[j].FromProto(source_layer.blobs(j));
+ target_blobs[j]->FromProto(source_layer.blobs(j));
}
}
}
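With blobs_ now holding shared_ptr, the wiring inside Net::Init follows one pattern throughout: allocate a blob, record its index under its name, and hand layers raw pointers obtained with .get(). A stripped-down, purely illustrative sketch of that bookkeeping (not the real Net code):

    #include <map>
    #include <string>
    #include <vector>
    #include "caffe/blob.hpp"

    namespace caffe {

    void register_top_blob(const std::string& blob_name,
                           std::vector<shared_ptr<Blob<float> > >* blobs,
                           std::map<std::string, int>* blob_name_to_idx,
                           std::vector<Blob<float>*>* top_vec) {
      shared_ptr<Blob<float> > blob_pointer(new Blob<float>());  // empty blob; the layer's SetUp reshapes it
      blobs->push_back(blob_pointer);
      (*blob_name_to_idx)[blob_name] = blobs->size() - 1;        // name -> index into the net's shared storage
      top_vec->push_back(blobs->back().get());                   // layers only ever see raw pointers
    }

    }  // namespace caffe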
diff --git a/src/caffe/net.hpp b/src/caffe/net.hpp
index debb59ec74ad594eb7b30ac30941294cdfbde9d2..45ea708dbca32e11505a30758ebfbf16b606a8a6 100644
--- a/src/caffe/net.hpp
+++ b/src/caffe/net.hpp
#include <vector>
#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
// returns the network name.
const string& name() { return name_; }
+ // returns the layer names
+ const vector<string>& layer_names() { return layer_names_; }
+ // returns the blob names
+ const vector<string>& blob_names() { return blob_names_; }
protected:
// Individual layers in the net
vector<string> layer_names_;
// blobs stores the blobs that store intermediate results between the
// layers.
- vector<Blob<Dtype> > blobs_;
+ vector<shared_ptr<Blob<Dtype> > > blobs_;
vector<string> blob_names_;
// bottom_vecs stores the vectors containing the input for each layer, except
// for the first layer whose bottom vec is provided by the network's input.
diff --git a/src/caffe/proto/lenet.prototxt b/src/caffe/proto/lenet.prototxt
--- a/src/caffe/proto/lenet.prototxt
+++ /dev/null
@@ -1,89 +0,0 @@
-name: "LeNet"
-bottom: "data"
-bottom: "label"
-layers {
- layer {
- name: "conv1"
- type: "conv"
- num_output: 20
- kernelsize: 5
- stride: 1
- }
- bottom: "data"
- top: "conv1"
-}
-layers {
- layer {
- name: "pool1"
- type: "pool"
- kernelsize: 2
- stride: 2
- pool: MAX
- }
- bottom: "conv1"
- top: "pool1"
-}
-layers {
- layer {
- name: "conv2"
- type: "conv"
- num_output: 50
- kernelsize: 5
- stride: 1
- }
- bottom: "pool1"
- top: "conv2"
-}
-layers {
- layer {
- name: "pool2"
- type: "pool"
- kernelsize: 2
- stride: 2
- pool: MAX
- }
- bottom: "conv2"
- top: "pool2"
-}
-layers {
- layer {
- name: "ip1"
- type: "innerproduct"
- num_output: 500
- }
- bottom: "pool2"
- top: "ip1"
-}
-layers {
- layer {
- name: "relu1"
- type: "relu"
- }
- bottom: "ip1"
- top: "relu1"
-}
-layers {
- layer {
- name: "ip2"
- type: "innerproduct"
- num_output: 10
- }
- bottom: "relu1"
- top: "ip2"
-}
-layers {
- layer {
- name: "prob"
- type: "softmax"
- }
- bottom: "ip2"
- top: "prob"
-}
-layers {
- layer {
- name: "loss"
- type: "softmaxloss"
- }
- bottom: "prob"
- bottom: "label"
-}
\ No newline at end of file
diff --git a/src/caffe/test/lenet.hpp b/src/caffe/test/lenet.hpp
--- /dev/null
+++ b/src/caffe/test/lenet.hpp
@@ -0,0 +1,100 @@
+#ifndef CAFFE_TEST_LENET_HPP_
+#define CAFFE_TEST_LENET_HPP_
+
+#include <string>
+
+namespace caffe {
+
+const char* kLENET = "name: \"LeNet\"\n\
+bottom: \"data\"\n\
+bottom: \"label\"\n\
+layers {\n\
+ layer {\n\
+ name: \"conv1\"\n\
+ type: \"conv\"\n\
+ num_output: 20\n\
+ kernelsize: 5\n\
+ stride: 1\n\
+ }\n\
+ bottom: \"data\"\n\
+ top: \"conv1\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"pool1\"\n\
+ type: \"pool\"\n\
+ kernelsize: 2\n\
+ stride: 2\n\
+ pool: MAX\n\
+ }\n\
+ bottom: \"conv1\"\n\
+ top: \"pool1\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"conv2\"\n\
+ type: \"conv\"\n\
+ num_output: 50\n\
+ kernelsize: 5\n\
+ stride: 1\n\
+ }\n\
+ bottom: \"pool1\"\n\
+ top: \"conv2\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"pool2\"\n\
+ type: \"pool\"\n\
+ kernelsize: 2\n\
+ stride: 2\n\
+ pool: MAX\n\
+ }\n\
+ bottom: \"conv2\"\n\
+ top: \"pool2\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"ip1\"\n\
+ type: \"innerproduct\"\n\
+ num_output: 500\n\
+ }\n\
+ bottom: \"pool2\"\n\
+ top: \"ip1\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"relu1\"\n\
+ type: \"relu\"\n\
+ }\n\
+ bottom: \"ip1\"\n\
+ top: \"relu1\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"ip2\"\n\
+ type: \"innerproduct\"\n\
+ num_output: 10\n\
+ }\n\
+ bottom: \"relu1\"\n\
+ top: \"ip2\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"prob\"\n\
+ type: \"softmax\"\n\
+ }\n\
+ bottom: \"ip2\"\n\
+ top: \"prob\"\n\
+}\n\
+layers {\n\
+ layer {\n\
+ name: \"loss\"\n\
+ type: \"multinomial_logistic_loss\"\n\
+ }\n\
+ bottom: \"prob\"\n\
+ bottom: \"label\"\n\
+}";
+
+} // namespace caffe
+
+#endif
diff --git a/src/caffe/test/test_blob.cpp b/src/caffe/test/test_blob.cpp
index 31bb9190a1c73b74469a92cba278021d0eb7a3a0..ba76ed1d8047011d2dd16007c4bc65dd36cf1344 100644
--- a/src/caffe/test/test_blob.cpp
+++ b/src/caffe/test/test_blob.cpp
EXPECT_EQ(this->blob_->count(), 120);
}
-TYPED_TEST(BlobSimpleTest, TestCopyConstructor) {
- Blob<TypeParam> source(2, 3, 4, 5);
- FillerParameter filler_param;
- UniformFiller<TypeParam> filler(filler_param);
- filler.Fill(&source);
- Blob<TypeParam> target(source);
- const TypeParam* source_data = source.cpu_data();
- const TypeParam* target_data = target.cpu_data();
- EXPECT_EQ(target.num(), source.num());
- EXPECT_EQ(target.channels(), source.channels());
- EXPECT_EQ(target.height(), source.height());
- EXPECT_EQ(target.width(), source.width());
- EXPECT_EQ(target.count(), source.count());
- for (int i = 0; i < source.count(); ++i) {
- EXPECT_EQ(source_data[i], target_data[i]);
- }
-}
-
}
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index dbaa7bacee1233c836b5ae59045b2401fd33072f..0c34861b2da88df1c21044f5aacec2f7ab3ded54 100644
--- a/src/caffe/test/test_gradient_check_util.hpp
+++ b/src/caffe/test/test_gradient_check_util.hpp
// First, figure out what blobs we need to check against.
vector<Blob<Dtype>*> blobs_to_check;
for (int i = 0; i < layer.params().size(); ++i) {
- blobs_to_check.push_back(&layer.params()[i]);
+ blobs_to_check.push_back(layer.params()[i].get());
}
if (check_bottom < 0) {
for (int i = 0; i < bottom.size(); ++i) {
diff --git a/src/caffe/test/test_innerproduct_layer.cpp b/src/caffe/test/test_innerproduct_layer.cpp
index 212b92f445f4bb5aa9630e49d9be20168a68199a..3ccd34e307a3bb45f5d7d245d2ceaa1af29a644a 100644
--- a/src/caffe/test/test_innerproduct_layer.cpp
+++ b/src/caffe/test/test_innerproduct_layer.cpp
namespace caffe {
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
-
+
template <typename Dtype>
class InnerProductLayerTest : public ::testing::Test {
protected:
diff --git a/src/caffe/test/test_net_proto.cpp b/src/caffe/test/test_net_proto.cpp
--- /dev/null
+++ b/src/caffe/test/test_net_proto.cpp
@@ -0,0 +1,47 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cstring>
+#include <cuda_runtime.h>
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/test/lenet.hpp"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class NetProtoTest : public ::testing::Test {};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(NetProtoTest, Dtypes);
+
+TYPED_TEST(NetProtoTest, TestSetup) {
+ NetParameter net_param;
+ string lenet_string(kLENET);
+ // Load the network
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ lenet_string, &net_param));
+ // check if things are right
+ EXPECT_EQ(net_param.layers_size(), 9);
+ EXPECT_EQ(net_param.bottom_size(), 2);
+ EXPECT_EQ(net_param.top_size(), 0);
+
+ // Now, initialize a network using the parameter
+ shared_ptr<Blob<TypeParam> > data(new Blob<TypeParam>(10, 1, 28, 28));
+ shared_ptr<Blob<TypeParam> > label(new Blob<TypeParam>(10, 1, 1, 1));
+ vector<Blob<TypeParam>*> bottom_vec;
+ bottom_vec.push_back(data.get());
+ bottom_vec.push_back(label.get());
+
+ Net<TypeParam> caffe_net(net_param, bottom_vec);
+ EXPECT_EQ(caffe_net.layer_names().size(), 9);
+ EXPECT_EQ(caffe_net.blob_names().size(), 10);
+}
+
+} // namespace caffe
diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp
index 3429618683943d95997a10eff1ee7d884b822148..a5d0c9fb2294fadf20a2c9c41a78a98f69c83676 100644
--- a/src/caffe/test/test_pooling_layer.cpp
+++ b/src/caffe/test/test_pooling_layer.cpp
EXPECT_EQ(this->blob_top_->width(), 2);
}
-TYPED_TEST(PoolingLayerTest, TestGPUMax) {
- LayerParameter layer_param;
- layer_param.set_kernelsize(3);
- layer_param.set_stride(2);
- layer_param.set_pool(LayerParameter_PoolMethod_MAX);
- Caffe::set_mode(Caffe::CPU);
- PoolingLayer<TypeParam> layer(layer_param);
- layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
- layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
- Blob<TypeParam> blob_reference(*this->blob_top_);
- Caffe::set_mode(Caffe::GPU);
- layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
- for (int i = 0; i < blob_reference.count(); ++i) {
- EXPECT_EQ(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i])
- << "debug: index " << i;
- }
-}
-
-TYPED_TEST(PoolingLayerTest, TestGPUAve) {
- LayerParameter layer_param;
- layer_param.set_kernelsize(3);
- layer_param.set_stride(2);
- layer_param.set_pool(LayerParameter_PoolMethod_AVE);
- Caffe::set_mode(Caffe::CPU);
- PoolingLayer<TypeParam> layer(layer_param);
- layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
- layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
- Blob<TypeParam> blob_reference(*this->blob_top_);
- Caffe::set_mode(Caffe::GPU);
- layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
- for (int i = 0; i < blob_reference.count(); ++i) {
- EXPECT_GE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] - 1e-4)
- << "debug: index " << i;
- EXPECT_LE(blob_reference.cpu_data()[i], this->blob_top_->cpu_data()[i] + 1e-4)
- << "debug: index " << i;
- }
-}
-
/*
TYPED_TEST(PoolingLayerTest, PrintGPUBackward) {
LayerParameter layer_param;