summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: a86d48a)
raw | patch | inline | side by side (parent: a86d48a)
author | Sergei Nikolaev <snikolaev@nvidia.com> | |
Tue, 17 Oct 2017 15:33:49 +0000 (08:33 -0700) | ||
committer | Sergei Nikolaev <snikolaev@nvidia.com> | |
Tue, 17 Oct 2017 15:33:49 +0000 (08:33 -0700) |
include/caffe/solver_factory.hpp | patch | blob | history | |
tools/caffe.cpp | patch | blob | history |
index 71c420757aac86bee444b072995280c50a256d6f..4fea5c8018c46093d13821df558edab094c84e67 100644 (file)
}
// Get a solver using a SolverParameter.
- static Solver* CreateSolver(const SolverParameter& param, Solver* root_solver = NULL) {
+ static Solver* CreateSolver(const SolverParameter& param, Solver* root_solver = NULL, int ranks = 0) {
const string& type = param.type();
- const size_t rank = param.device_id();
+ const size_t rank = ranks == 1 ? 0 : param.device_id();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index f74527b7b50737e0ad54b343d1ec51c63ea7e24a..41e8f52355c56936b36a840d0a6ec94e31b91f66 100644 (file)
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));
- shared_ptr<caffe::Solver> solver(caffe::SolverRegistry::CreateSolver(solver_param));
+ shared_ptr<caffe::Solver> solver(caffe::SolverRegistry::CreateSolver(solver_param,
+ nullptr, gpus.size()));
solver->SetActionFunction(signal_handler.GetActionFunction());
if (FLAGS_snapshot.size()) {