diff --git a/examples/train_net.cpp b/examples/train_net.cpp
index 3abb1c3955557486d54534666ea8be2e0b78861d..d84619b2d44c10028fa5fdd215a020c03b677b9a 100644 (file)
--- a/examples/train_net.cpp
+++ b/examples/train_net.cpp
using namespace caffe;
int main(int argc, char** argv) {
using namespace caffe;
int main(int argc, char** argv) {
- if (argc < 3) {
- LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
- << "[resume_point_file]";
+ ::google::InitGoogleLogging(argv[0]);
+ if (argc < 2) {
+ LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
return 0;
}
return 0;
}
- cudaSetDevice(0);
+ Caffe::SetDevice(0);
Caffe::set_mode(Caffe::GPU);
Caffe::set_mode(Caffe::GPU);
- Caffe::set_phase(Caffe::TRAIN);
-
- NetParameter net_param;
- ReadProtoFromTextFile(argv[1], &net_param);
- vector<Blob<float>*> bottom_vec;
- Net<float> caffe_net(net_param, bottom_vec);
SolverParameter solver_param;
SolverParameter solver_param;
- ReadProtoFromTextFile(argv[2], &solver_param);
+ ReadProtoFromTextFile(argv[1], &solver_param);
- LOG(ERROR) << "Starting Optimization";
+ LOG(INFO) << "Starting Optimization";
SGDSolver<float> solver(solver_param);
SGDSolver<float> solver(solver_param);
- if (argc == 4) {
- LOG(ERROR) << "Resuming from " << argv[3];
- solver.Solve(&caffe_net, argv[3]);
+ if (argc == 3) {
+ LOG(INFO) << "Resuming from " << argv[2];
+ solver.Solve(argv[2]);
} else {
} else {
- solver.Solve(&caffe_net);
+ solver.Solve();
}
}
- LOG(ERROR) << "Optimization Done.";
+ LOG(INFO) << "Optimization Done.";
return 0;
}
return 0;
}