diff --git a/examples/train_net.cpp b/examples/train_net.cpp
index 06ca12137929afb4848630a5ee6f850ff9708036..ce62616b1186aa398c07cb8edfbad6df8ac0405a 100644 (file)
--- a/examples/train_net.cpp
+++ b/examples/train_net.cpp
// This is a simple script that allows one to quickly train a network whose
// parameters are specified by text format protocol buffers.
// Usage:
-// train_net net_proto_file solver_proto_file
+// train_net net_proto_file solver_proto_file [resume_point_file]
#include <cuda_runtime.h>
using namespace caffe;
int main(int argc, char** argv) {
- cudaSetDevice(0);
- Caffe::set_mode(Caffe::GPU);
- Caffe::set_phase(Caffe::TRAIN);
-
- NetParameter net_param;
- ReadProtoFromTextFile(argv[1], &net_param);
- vector<Blob<float>*> bottom_vec;
- Net<float> caffe_net(net_param, bottom_vec);
+ ::google::InitGoogleLogging(argv[0]);
+ if (argc < 2) {
+ LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
+ return 0;
+ }
SolverParameter solver_param;
- ReadProtoFromTextFile(argv[2], &solver_param);
+ ReadProtoFromTextFile(argv[1], &solver_param);
- LOG(ERROR) << "Starting Optimization";
+ LOG(INFO) << "Starting Optimization";
SGDSolver<float> solver(solver_param);
- solver.Solve(&caffe_net);
- LOG(ERROR) << "Optimization Done.";
+ if (argc == 3) {
+ LOG(INFO) << "Resuming from " << argv[2];
+ solver.Solve(argv[2]);
+ } else {
+ solver.Solve();
+ }
+ LOG(INFO) << "Optimization Done.";
return 0;
}