summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: 5d14903)
raw | patch | inline | side by side (parent: 5d14903)
author | Yangqing Jia <jiayq84@gmail.com> | Wed, 6 Nov 2013 23:27:26 +0000 (15:27 -0800)
committer | Yangqing Jia <jiayq84@gmail.com> | Wed, 6 Nov 2013 23:27:26 +0000 (15:27 -0800)
index 6b5ee78aee54a50d93dbaf19f7fa97c6f404038f..6a6b86dcdf8f5137d272fca54d02f6afb1dde16c 100644 (file)
char key_cstr[100];
leveldb::WriteBatch* batch = new leveldb::WriteBatch();
for (int line_id = 0; line_id < lines.size(); ++line_id) {
- ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second,
- &datum);
+ if (!ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second,
+ &datum)) {
+ continue;
+ };
// sequential
sprintf(key_cstr, "%08d_%s", line_id, lines[line_id].first.c_str());
string value;
index b0130de1ccc7abac6c21b509a7a49dd8c476243e..f05a8c5b8f893976da10679c473fdec1ae8e7a57 100644 (file)
sum_blob.set_channels(datum.channels());
sum_blob.set_height(datum.height());
sum_blob.set_width(datum.width());
+ const int data_size = datum.channels() * datum.height() * datum.width();
for (int i = 0; i < datum.data().size(); ++i) {
sum_blob.add_data(0.);
}
// just a dummy operation
datum.ParseFromString(it->value().ToString());
const string& data = datum.data();
+ CHECK_EQ(data.size(), data_size) << "Incorrect data field size " << data.size();
for (int i = 0; i < data.size(); ++i) {
sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
}
index b8f983baa3ed72097f4ad94971cba5fa7ac36a02..e5ae41e813a7a0f1c617f3b4074686ad38f3512b 100644 (file)
return 0;
}
- Caffe::SetDevice(0);
- Caffe::set_mode(Caffe::GPU);
+ //Caffe::SetDevice(0);
+ Caffe::set_mode(Caffe::CPU);
SolverParameter solver_param;
ReadProtoFromTextFile(argv[1], &solver_param);
index 4e084058cff29f3ea1cca491b6012eed55c7bbc6..42836427a6eea6e2edd75c33bacba9556f0bfd06 100644 (file)
#include "caffe/solver.hpp"
+
namespace caffe {
template <typename Dtype>
index 03df4b2e84a146e4ddf9524402bf637bd78ac4f1..e37e740b60cd387b3c63b8b1e40291fcb8342b8a 100644 (file)
WriteProtoToBinaryFile(proto, filename.c_str());
}
-void ReadImageToDatum(const string& filename, const int label, Datum* datum);
+bool ReadImageToDatum(const string& filename, const int label, Datum* datum);
} // namespace caffe
index cbc7c047c7559e2980ee77ffd26f609873d29b47..0de1a2e285b28fc27a934509e8a0e1a8279408d9 100644 (file)
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Solver<Dtype>::Restore(resume_file);
}
+ next_snapshot_ = this->iter_ + this->param_.snapshot();
// the main loop.
LOG(INFO) << "Waiting for incoming updates...";
total_received += count;
}
LOG(INFO) << "Received " << total_received << " variables.";
- // Check Error
- if (!data_stream) {
- LOG(ERROR) << "Error in receiving.";
+ // Check error: if there is any error in the receiving phase, we will not
+ // trust the passed-in update.
+ if (data_stream.error()) {
+ LOG(ERROR) << "Error in receiving. Error code: " << data_stream.error().message();
} else {
// If the read is successful, update the network.
this->iter_ += incoming_iter;
}
LOG(INFO) << "Sent " << total_sent << " variables.";
data_stream.flush();
- data_stream.close();
}
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
+ next_display_ = this->iter_ + this->param_.display();
while (this->iter_++ < this->param_.max_iter()) {
Dtype loss = this->net_->ForwardBackward(bottom_vec);
ComputeUpdateValue();
if (this->param_.display() && this->iter_ > next_display_) {
LOG(INFO) << "Iteration " << this->iter_ << ", loss = " << loss;
- next_display_ += this->param_.display();
+ next_display_ = this->iter_ + this->param_.display();
}
}
LOG(INFO) << "Optimization Done.";
template <typename Dtype>
void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
- CHECK(data_stream) << "Error in connection.";
+ if (!data_stream) {
+ LOG(FATAL) << "Unable to connect. Error code: " << data_stream.error().message();
+ }
data_stream.write(reinterpret_cast<char*>(&receive_only), sizeof(receive_only));
if (!receive_only) {
LOG(INFO) << "Sending local changes.";
total_sent += count;
}
LOG(INFO) << "Sent " << total_sent << " variables.";
+ CHECK(!data_stream.error()) << "Error in sending. Error code: "
+ << data_stream.error().message();
}// else {
// LOG(INFO) << "Not sending local changes. Receive only.";
//}
memset(net_params[param_id]->mutable_cpu_diff(), 0,
net_params[param_id]->count() * sizeof(Dtype));
}
+ CHECK(!data_stream.error()) << "Error in communication. Error code: "
+ << data_stream.error().message();
LOG(INFO) << "Received " << total_received << " variables.";
// Set the next send iter.
next_send_iter_ = this->iter_ + this->param_.communication_interval();
index cbe306d60c25bff4fc59482d4a090edd9378925f..939df91164c94371df279198dcec47e4da8201ca 100644 (file)
optional int32 iter = 1; // The current iteration
optional string learned_net = 2; // The file that stores the learned net.
repeated BlobProto history = 3; // The history for sgd solvers
-}
\ No newline at end of file
+}
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index a3c520f0eeb06ff1b55854f2960f0e1f4a16b73c..9bca5be48b883d27f967fad7687801bb09479611 100644 (file)
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
}
-void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+bool ReadImageToDatum(const string& filename, const int label, Datum* datum) {
cv::Mat cv_img;
cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
- CHECK(cv_img.data) << "Could not open or find the image.";
+ if (!cv_img.data) {
+ LOG(ERROR) << "Could not open or find file " << filename;
+ return false;
+ }
datum->set_channels(3);
datum->set_height(cv_img.rows);
datum->set_width(cv_img.cols);
}
}
}
+ return true;
}
} // namespace caffe