X-Git-Url: https://git.ti.com/gitweb?p=jacinto-ai%2Fcaffe-jacinto.git;a=blobdiff_plain;f=examples%2Ftrain_net.cpp;h=d84619b2d44c10028fa5fdd215a020c03b677b9a;hp=3abb1c3955557486d54534666ea8be2e0b78861d;hb=82b912be849a7bec9bfee92a8e5d81182f4130f2;hpb=25a865cd8ba0995c89907990fedaa357282b9a64;ds=sidebyside diff --git a/examples/train_net.cpp b/examples/train_net.cpp index 3abb1c39..d84619b2 100644 --- a/examples/train_net.cpp +++ b/examples/train_net.cpp @@ -14,33 +14,27 @@ using namespace caffe; int main(int argc, char** argv) { - if (argc < 3) { - LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file " - << "[resume_point_file]"; + ::google::InitGoogleLogging(argv[0]); + if (argc < 2) { + LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]"; return 0; } - cudaSetDevice(0); + Caffe::SetDevice(0); Caffe::set_mode(Caffe::GPU); - Caffe::set_phase(Caffe::TRAIN); - - NetParameter net_param; - ReadProtoFromTextFile(argv[1], &net_param); - vector*> bottom_vec; - Net caffe_net(net_param, bottom_vec); SolverParameter solver_param; - ReadProtoFromTextFile(argv[2], &solver_param); + ReadProtoFromTextFile(argv[1], &solver_param); - LOG(ERROR) << "Starting Optimization"; + LOG(INFO) << "Starting Optimization"; SGDSolver solver(solver_param); - if (argc == 4) { - LOG(ERROR) << "Resuming from " << argv[3]; - solver.Solve(&caffe_net, argv[3]); + if (argc == 3) { + LOG(INFO) << "Resuming from " << argv[2]; + solver.Solve(argv[2]); } else { - solver.Solve(&caffe_net); + solver.Solve(); } - LOG(ERROR) << "Optimization Done."; + LOG(INFO) << "Optimization Done."; return 0; }