diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index c6dfce19c038239b841b00b24c18c36bee8af025..22d27436e77422bf418bf126af12709ec6eb6aa1 100644 (file)
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
// For each layer, set up their input and output
bottom_vecs_.resize(param.layers_size());
top_vecs_.resize(param.layers_size());
// For each layer, set up their input and output
bottom_vecs_.resize(param.layers_size());
top_vecs_.resize(param.layers_size());
+ bottom_id_vecs_.resize(param.layers_size());
+ top_id_vecs_.resize(param.layers_size());
for (int i = 0; i < param.layers_size(); ++i) {
const LayerConnection& layer_connection = param.layers(i);
const LayerParameter& layer_param = layer_connection.layer();
for (int i = 0; i < param.layers_size(); ++i) {
const LayerConnection& layer_connection = param.layers(i);
const LayerParameter& layer_param = layer_connection.layer();
LOG(INFO) << layer_param.name() << " <- " << blob_name;
bottom_vecs_[i].push_back(
blobs_[blob_name_to_idx[blob_name]].get());
LOG(INFO) << layer_param.name() << " <- " << blob_name;
bottom_vecs_[i].push_back(
blobs_[blob_name_to_idx[blob_name]].get());
+ bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
available_blobs.erase(blob_name);
}
for (int j = 0; j < layer_connection.top_size(); ++j) {
available_blobs.erase(blob_name);
}
for (int j = 0; j < layer_connection.top_size(); ++j) {
blob_name_to_idx[blob_name] = blob_names_.size() - 1;
available_blobs.insert(blob_name);
top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
blob_name_to_idx[blob_name] = blob_names_.size() - 1;
available_blobs.insert(blob_name);
top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
+ top_id_vecs_[i].push_back(blob_names_.size() - 1);
}
}
LOG(INFO) << "Checking top blobs.";
}
}
LOG(INFO) << "Checking top blobs.";
for (int i = 0; i < layers_.size(); ++i) {
LOG(INFO) << "Setting up " << layer_names_[i];
layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
for (int i = 0; i < layers_.size(); ++i) {
LOG(INFO) << "Setting up " << layer_names_[i];
layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
- vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i].params();
+ vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i]->params();
for (int j = 0; j < layer_params.size(); ++j) {
params_.push_back(layer_params[j]);
}
for (int j = 0; j < layer_params.size(); ++j) {
params_.push_back(layer_params[j]);
}
vector<Blob<Dtype>*>* top) {
// Copy bottom to internal bottom
for (int i = 0; i < bottom.size(); ++i) {
vector<Blob<Dtype>*>* top) {
// Copy bottom to internal bottom
for (int i = 0; i < bottom.size(); ++i) {
- memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
- bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
+ blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
}
for (int i = 0; i < layers_.size(); ++i) {
layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
}
// Copy internal top to top
for (int i = 0; i < (*top).size(); ++i) {
}
for (int i = 0; i < layers_.size(); ++i) {
layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
}
// Copy internal top to top
for (int i = 0; i < (*top).size(); ++i) {
- NOT_IMPLEMENTED;
+ (*top)[i]->CopyFrom(*blobs_[net_output_blob_indices_[i]]);
}
}
}
}
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
}
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
}
- for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
- param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+ for (int i = 0; i < net_output_blob_indices_.size(); ++i) {
+ param->add_top(blob_names_[net_output_blob_indices_[i]]);
}
for (int i = 0; i < layers_.size(); ++i) {
LayerConnection* layer_connection = param->add_layers();
}
for (int i = 0; i < layers_.size(); ++i) {
LayerConnection* layer_connection = param->add_layers();
+ for (int j = 0; j < bottom_id_vecs_[i].size(); ++i) {
+ layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
+ }
+ for (int j = 0; j < top_id_vecs_[i].size(); ++i) {
+ layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
+ }
+ LayerParameter* layer_parameter = layer_connection->mutable_layer();
+ layers_[i]->ToProto(layer_parameter);
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Update() {
+ for (int i = 0; i < params_.size(); ++i) {
+ params_[i]->Update();
}
}
}
}