author    Yangqing Jia <jiayq84@gmail.com>    Fri, 27 Sep 2013 23:59:43 +0000 (16:59 -0700)
committer Yangqing Jia <jiayq84@gmail.com>    Fri, 27 Sep 2013 23:59:43 +0000 (16:59 -0700)
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/optimization/solver.cpp  [new file with mode: 0644]
src/caffe/optimization/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 616274018cc9c1ba23101315522eaab9861a5da9..35e5b04adbddc82585733defb6eb5d25529ef9c5 100644 (file)
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
// Copyright 2013 Yangqing Jia
+#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "caffe/blob.hpp"
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
+ int old_count = count_;
CHECK_GE(num, 0);
CHECK_GE(channels, 0);
CHECK_GE(height, 0);
width_ = width;
count_ = num_ * channels_ * height_ * width_;
if (count_) {
- data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ if (old_count != count_) {
+ data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ }
} else {
data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
// We will perform update based on where the data is located.
}
+template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
+ if (num_ != source.num() || channels_ != source.channels() ||
+ height_ != source.height() || width_ != source.width()) {
+ if (reshape) {
+ Reshape(source.num(), source.channels(), source.height(), source.width());
+ } else {
+ LOG(FATAL) << "Trying to copy blobs of different sizes.";
+ }
+ }
+ switch (Caffe::mode()) {
+ case Caffe::GPU:
+ if (copy_diff) {
+ CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
+ sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+ } else {
+ CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
+ sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+ }
+ break;
+ case Caffe::CPU:
+ if (copy_diff) {
+ memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+ sizeof(Dtype) * count_);
+ } else {
+ memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+ sizeof(Dtype) * count_);
+ }
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode.";
+ }
+}
+
template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
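The Reshape() change above skips re-allocating the underlying SyncedMemory buffers when the element count is unchanged. A minimal sketch of the behaviour this enables; the Blob(num, channels, height, width) constructor and cpu_data() accessor are assumed from the surrounding tree, and treating the buffer reuse as observable through pointer identity is an assumption of this sketch, not a documented guarantee.

// Sketch (not part of this commit): exercising the new Reshape() fast path.
#include "caffe/blob.hpp"

using caffe::Blob;

void reshape_sketch() {
  Blob<float> blob(2, 3, 4, 5);        // count_ == 120
  const float* before = blob.cpu_data();
  blob.Reshape(4, 5, 3, 2);            // still 120 elements: with this patch the
                                       // SyncedMemory buffers are kept, not re-created
  const float* after = blob.cpu_data();
  // before == after is the expected (assumed) outcome of the old_count check.
  blob.Reshape(1, 3, 4, 5);            // count changes: buffers are re-allocated
}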
diff --git a/src/caffe/blob.hpp b/src/caffe/blob.hpp
index f0e19c277de09ae49591a4d774cba501376042e5..f31d3b0f693c7b3f3a60412e390b7f0be3064b32 100644 (file)
--- a/src/caffe/blob.hpp
+++ b/src/caffe/blob.hpp
const int w = 0) const {
return ((n * channels_ + c) * height_ + h) * width_ + w;
}
+ // Copy from source. If copy_diff is false, we copy the data; if copy_diff
+ // is true, we copy the diff.
+ void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
+ bool reshape = false);
inline Dtype data_at(const int n, const int c, const int h,
const int w) const {
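The comment above spells out the two flags on the new CopyFrom(). A short usage sketch follows; the default Blob constructor is an assumption, everything else is what this patch declares.

// Sketch (not part of this commit): the two CopyFrom() flags in use.
#include "caffe/blob.hpp"

using caffe::Blob;

void copy_from_sketch(const Blob<float>& source) {
  Blob<float> target;   // empty blob (assumed default constructor)
  // reshape = true resizes target to source's shape instead of LOG(FATAL)-ing
  // on the size mismatch; copy_diff = false copies the data.
  target.CopyFrom(source, /* copy_diff = */ false, /* reshape = */ true);
  // Shapes now match, so the gradients can be copied without reshaping.
  target.CopyFrom(source, /* copy_diff = */ true);
}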
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index c6dfce19c038239b841b00b24c18c36bee8af025..22d27436e77422bf418bf126af12709ec6eb6aa1 100644 (file)
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
// For each layer, set up their input and output
bottom_vecs_.resize(param.layers_size());
top_vecs_.resize(param.layers_size());
+ bottom_id_vecs_.resize(param.layers_size());
+ top_id_vecs_.resize(param.layers_size());
for (int i = 0; i < param.layers_size(); ++i) {
const LayerConnection& layer_connection = param.layers(i);
const LayerParameter& layer_param = layer_connection.layer();
LOG(INFO) << layer_param.name() << " <- " << blob_name;
bottom_vecs_[i].push_back(
blobs_[blob_name_to_idx[blob_name]].get());
+ bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
available_blobs.erase(blob_name);
}
for (int j = 0; j < layer_connection.top_size(); ++j) {
blob_name_to_idx[blob_name] = blob_names_.size() - 1;
available_blobs.insert(blob_name);
top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
+ top_id_vecs_[i].push_back(blob_names_.size() - 1);
}
}
LOG(INFO) << "Checking top blobs.";
for (int i = 0; i < layers_.size(); ++i) {
LOG(INFO) << "Setting up " << layer_names_[i];
layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
- vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i].params();
+ vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i]->params();
for (int j = 0; j < layer_params.size(); ++j) {
params_.push_back(layer_params[j]);
}
vector<Blob<Dtype>*>* top) {
// Copy bottom to internal bottom
for (int i = 0; i < bottom.size(); ++i) {
- memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
- bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
+ blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
}
for (int i = 0; i < layers_.size(); ++i) {
layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
}
// Copy internal top to top
for (int i = 0; i < (*top).size(); ++i) {
- NOT_IMPLEMENTED;
+ (*top)[i]->CopyFrom(*blobs_[net_output_blob_indices_[i]]);
}
}
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
}
- for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
- param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+ for (int i = 0; i < net_output_blob_indices_.size(); ++i) {
+ param->add_top(blob_names_[net_output_blob_indices_[i]]);
}
for (int i = 0; i < layers_.size(); ++i) {
LayerConnection* layer_connection = param->add_layers();
+ for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
+ layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
+ }
+ for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
+ layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
+ }
+ LayerParameter* layer_parameter = layer_connection->mutable_layer();
+ layers_[i]->ToProto(layer_parameter);
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Update() {
+ for (int i = 0; i < params_.size(); ++i) {
+ params_[i]->Update();
}
}
diff --git a/src/caffe/net.hpp b/src/caffe/net.hpp
index 719267c6402c57c7f30b4582d176410ad79773c0..1f1a80309256ae84eb284e63ae5617f844d78120 100644 (file)
--- a/src/caffe/net.hpp
+++ b/src/caffe/net.hpp
// been provided during the forward pass.
Dtype Backward();
+ Dtype ForwardBackWard(const vector<Blob<Dtype>* > & bottom,
+ vector<Blob<Dtype>*>* top) {
+ Forward(bottom, top);
+ return Backward();
+ }
+
// For an already initialized net, CopyTrainedLayersFrom() copies the already
// trained layers from another net parameter instance.
void CopyTrainedLayersFrom(const NetParameter& param);
inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
// returns the parameters
vector<shared_ptr<Blob<Dtype> > >& params() { return params_; };
+ // Updates the network
+ void Update();
protected:
// Individual layers in the net
// bottom_vecs stores the vectors containing the input for each layer, except
// for the first layer whose bottom vec is provided by the network's input.
vector<vector<Blob<Dtype>*> > bottom_vecs_;
+ vector<vector<int> > bottom_id_vecs_;
// top_vecs stores the vectors containing the output for each layer, except
// for the last layer (likewise)
vector<vector<Blob<Dtype>*> > top_vecs_;
+ vector<vector<int> > top_id_vecs_;
// blob indices for the input and the output of the net.
vector<int> net_input_blob_indices_;
vector<int> net_output_blob_indices_;
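With the new ForwardBackWard() helper and Net::Update(), a training iteration is a four-step sequence: forward, backward, scale the diffs, apply them. A sketch of one iteration outside the solver; it assumes, as the solver code below does, that Blob::Update() applies each parameter's diff to its data.

// Sketch (not part of this commit): one training step using the new Net API.
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/net.hpp"

template <typename Dtype>
Dtype train_step_sketch(caffe::Net<Dtype>* net) {
  // A self-contained net (e.g. one with a data layer) takes no external
  // bottoms, mirroring the dummy vectors in Solver<Dtype>::Solve().
  std::vector<caffe::Blob<Dtype>*> bottom_vec;
  std::vector<caffe::Blob<Dtype>*> top_vec;
  Dtype loss = net->ForwardBackWard(bottom_vec, &top_vec);
  // ... scale each parameter's diff by the learning rate here
  //     (this is what SGDSolver<Dtype>::ComputeUpdateValue() does) ...
  net->Update();   // calls params_[i]->Update() for every learnable blob
  return loss;
}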
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
--- /dev/null
+++ b/src/caffe/optimization/solver.cpp
@@ -0,0 +1,113 @@
+// Copyright Yangqing Jia 2013
+
+#include <fstream>
+#include <sstream>
+#include <string>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/net.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using std::stringstream;
+using std::ofstream;
+
+namespace caffe {
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(Net<Dtype>* net) {
+ net_ = net;
+ LOG(INFO) << "Solving net " << net_->name();
+ iter_ = 0;
+ // For a network that is trained by the solver, no bottom or top vecs
+ // should be given, and we will just provide dummy vecs.
+ vector<Blob<Dtype>*> bottom_vec;
+ vector<Blob<Dtype>*> top_vec;
+ while (iter_++ < param_.max_iter()) {
+ Dtype loss = net_->ForwardBackWard(bottom_vec, &top_vec);
+ ComputeUpdateValue();
+ net->Update();
+
+ // Check if we need to do snapshot
+ if (iter_ % param_.snapshot() == 0) {
+ // TODO(Yangqing): snapshot
+ }
+ LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
+ }
+ LOG(INFO) << "Optimization Done.";
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Snapshot(bool is_final) {
+ NetParameter net_param;
+ net_->ToProto(&net_param);
+ stringstream ss;
+ ss << param_.snapshot_prefix();
+ if (is_final) {
+ ss << "_final";
+ } else {
+ ss << "_iter_" << iter_;
+ }
+ ofstream output_file;
+ output_file.open(ss.str().c_str());
+ CHECK(net_param.SerializeToOstream(&output_file));
+ output_file.close();
+}
+
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+ Dtype rate;
+ const string& lr_policy = this->param_.lr_policy();
+ if (lr_policy == "fixed") {
+ rate = this->param_.base_lr();
+ } else if (lr_policy == "exp") {
+ rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+ } else if (lr_policy == "inv") {
+ rate = this->param_.base_lr() *
+ pow(Dtype(1) + this->param_.gamma() * this->iter_,
+ this->param_.power());
+ } else {
+ LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+ }
+ rate = min(max(rate, this->param_.min_lr()), this->param_.max_lr());
+ return rate;
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue() {
+ // First of all, see if we need to initialize the history
+ vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ if (this->iter_ == 1 && this->param_.momentum() > 0) {
+ LOG(INFO) << "Using momentum " << this->param_.momentum();
+ for (int i = 0; i < net_params.size(); ++i) {
+ const Blob<Dtype>* net_param = net_params[i].get();
+ history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+ net_param->num(), net_param->channels(), net_param->height(),
+ net_param->width())));
+ }
+ }
+ // get the learning rate
+ Dtype rate = GetLearningRate();
+ if (this->param_.momentum() == 0) {
+ for (int i = 0; i < net_params.size(); ++i) {
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ caffe_scal(net_params[i]->count(), rate,
+ net_params[i]->mutable_cpu_diff());
+ break;
+ case Caffe::GPU:
+ caffe_gpu_scal(net_params[i]->count(), rate,
+ net_params[i]->mutable_gpu_diff());
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+ }
+ } else {
+ NOT_IMPLEMENTED;
+ }
+}
+
+INSTANTIATE_CLASS(Solver);
+INSTANTIATE_CLASS(SGDSolver);
+
+} // namespace caffe
\ No newline at end of file
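GetLearningRate() above picks one of three schedules based on lr_policy. The following standalone sketch restates the three formulas with plain arguments (hypothetical values, no SolverParameter) so the decay can be checked numerically.

// Sketch (not part of this commit): the three lr_policy schedules from
// SGDSolver<Dtype>::GetLearningRate(), restated with plain arguments.
#include <cmath>
#include <cstdio>
#include <string>

float learning_rate(const std::string& policy, float base_lr, float gamma,
                    float power, int iter) {
  if (policy == "fixed") {
    return base_lr;                                         // constant
  } else if (policy == "exp") {
    return base_lr * std::pow(gamma, iter);                 // geometric decay
  } else if (policy == "inv") {
    return base_lr * std::pow(1.0f + gamma * iter, power);  // polynomial decay
  }
  return base_lr;                                           // unknown policy
}

int main() {
  // Hypothetical settings: base_lr = 0.01, gamma = 1e-4, power = -0.75.
  // The "inv" rate then decays from 0.01 at iter 0 to ~0.0059 at iter 10000
  // and ~0.0017 at iter 100000.
  const int iters[] = {0, 10000, 100000};
  for (int i = 0; i < 3; ++i) {
    std::printf("iter %6d  rate %.6f\n", iters[i],
                learning_rate("inv", 0.01f, 1e-4f, -0.75f, iters[i]));
  }
  return 0;
}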
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 0c680e34d01ed30d228d12dad4575bf6eeba98d4..0a78d88000178387b26d5e84f2498a37a3a9c8ed 100644 (file)
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
namespace caffe {
+template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param)
: param_(param) {}
- void Solve(Net* net);
+ // The main entry of the solver function.
+ void Solve(Net<Dtype>* net);
protected:
+ // Get the update value for the current iteration.
+ virtual void ComputeUpdateValue() = 0;
+ void Snapshot(bool is_final = false);
SolverParameter param_;
+ int iter_;
+ Net<Dtype>* net_;
+
+ DISABLE_COPY_AND_ASSIGN(Solver);
};
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+ explicit SGDSolver(const SolverParameter& param)
+ : Solver<Dtype>(param) {}
+
+ protected:
+ Dtype GetLearningRate();
+ virtual void ComputeUpdateValue();
+ // history maintains the historical momentum data.
+ vector<shared_ptr<Blob<Dtype> > > history_;
+};
+
+
} // namespace caffe
#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_
\ No newline at end of file
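Putting the pieces together, a training driver looks roughly like this. The Net constructor signature is assumed from the rest of the tree, and how the two protobuf messages get filled is left open (see the SolverParameter sketch after the caffe.proto hunk below).

// Sketch (not part of this commit): driving training with the classes above.
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/net.hpp"
#include "caffe/optimization/solver.hpp"
#include "caffe/proto/caffe.pb.h"

void train_sketch(const caffe::NetParameter& net_param,
                  const caffe::SolverParameter& solver_param) {
  // The net feeds itself from its own layers, so no external input blobs
  // are passed in (assumed Net constructor signature).
  std::vector<caffe::Blob<float>*> empty_input;
  caffe::Net<float> net(net_param, empty_input);
  caffe::SGDSolver<float> solver(solver_param);
  solver.Solve(&net);  // runs max_iter() iterations of ForwardBackWard + Update
}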
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 732c2eecfda153dbf111d0c573de4e2aa7544078..9d691d271ca0c508c94763088fb8db5bf225aed6 100644 (file)
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
optional float gamma = 8; // The parameter to compute the learning rate.
optional float power = 9; // The parameter to compute the learning rate.
optional float momentum = 10; // The momentum value.
+
+ optional string snapshot_prefix = 11; // The prefix for the snapshot.
}
\ No newline at end of file
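The solver reads its settings from SolverParameter, and this commit adds snapshot_prefix to the fields solver.cpp already reads. A sketch of filling the message in code; the setters are the ones protoc generates for those fields, and the values are hypothetical.

// Sketch (not part of this commit): populating SolverParameter in code.
#include "caffe/proto/caffe.pb.h"

caffe::SolverParameter make_solver_param_sketch() {
  caffe::SolverParameter param;
  param.set_base_lr(0.01);
  param.set_lr_policy("inv");
  param.set_gamma(1e-4);
  param.set_power(-0.75);
  param.set_momentum(0.9);
  param.set_max_iter(10000);
  param.set_snapshot(1000);            // snapshot every 1000 iterations
  param.set_snapshot_prefix("lenet");  // written as lenet_iter_N / lenet_final
  return param;
}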
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 1949a703d75a9e71c91182dff0a7c6ae24689ef8..7cd3b26226e2ad697864f0b7765b12241ce449b7 100644 (file)
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
cblas_dscal(N, alpha, X, 1);
}
+template <>
+void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
+ CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
+void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
+ CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
vsSqr(n, a, y);
diff --git a/src/caffe/util/math_functions.hpp b/src/caffe/util/math_functions.hpp
index 822ef31be74acea4ed4a53851229c67bf32fbfed..f09afe38eedfc2a1223ce6fc57af2dc45eddd4df 100644 (file)
--- a/src/caffe/util/math_functions.hpp
+++ b/src/caffe/util/math_functions.hpp
template <typename Dtype>
void caffe_scal(const int N, const Dtype alpha, Dtype *X);
+template <typename Dtype>
+void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
+
template <typename Dtype>
void caffe_sqr(const int N, const Dtype* a, Dtype* y);
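The new caffe_gpu_scal mirrors the existing caffe_scal: both scale a length-N array in place by alpha, dispatching to cuBLAS or CBLAS through the specializations in math_functions.cpp. A small CPU-side sketch:

// Sketch (not part of this commit): scaling a buffer with the templated wrapper.
#include <vector>
#include "caffe/util/math_functions.hpp"

void scal_sketch() {
  std::vector<float> x(100, 1.0f);
  // Dispatches to cblas_sscal via the float specialization of caffe_scal.
  caffe::caffe_scal(static_cast<int>(x.size()), 0.5f, &x[0]);
  // Every element of x is now 0.5. The GPU variant, caffe_gpu_scal, does the
  // same through cublasSscal on device memory.
}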