]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blobdiff - examples/train_net.cpp
solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / examples / train_net.cpp
index 3abb1c3955557486d54534666ea8be2e0b78861d..d84619b2d44c10028fa5fdd215a020c03b677b9a 100644 (file)
 using namespace caffe;
 
 int main(int argc, char** argv) {
 using namespace caffe;
 
 int main(int argc, char** argv) {
-  if (argc < 3) {
-    LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
-               << "[resume_point_file]";
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 2) {
+    LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
     return 0;
   }
 
     return 0;
   }
 
-  cudaSetDevice(0);
+  Caffe::SetDevice(0);
   Caffe::set_mode(Caffe::GPU);
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1], &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
 
   SolverParameter solver_param;
 
   SolverParameter solver_param;
-  ReadProtoFromTextFile(argv[2], &solver_param);
+  ReadProtoFromTextFile(argv[1], &solver_param);
 
 
-  LOG(ERROR) << "Starting Optimization";
+  LOG(INFO) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
   SGDSolver<float> solver(solver_param);
-  if (argc == 4) {
-    LOG(ERROR) << "Resuming from " << argv[3];
-    solver.Solve(&caffe_net, argv[3]);
+  if (argc == 3) {
+    LOG(INFO) << "Resuming from " << argv[2];
+    solver.Solve(argv[2]);
   } else {
   } else {
-    solver.Solve(&caffe_net);
+    solver.Solve();
   }
   }
-  LOG(ERROR) << "Optimization Done.";
+  LOG(INFO) << "Optimization Done.";
 
   return 0;
 }
 
   return 0;
 }