Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - examples/train_net.cpp
need backward computation, and train_net resume point. Not debugged.
[jacinto-ai/caffe-jacinto.git] / examples / train_net.cpp
1 // Copyright 2013 Yangqing Jia
2 //
3 // This is a simple script that allows one to quickly train a network whose
4 // parameters are specified by text format protocol buffers.
5 // Usage:
6 //    train_net net_proto_file solver_proto_file [resume_point_file]
8 #include <cuda_runtime.h>
10 #include <cstring>
12 #include "caffe/caffe.hpp"
14 using namespace caffe;
// Trains the network described by the text-format NetParameter in argv[1]
// using the SGD solver settings in the text-format SolverParameter in argv[2].
// If argv[3] is given, optimization resumes from that file.
//
// Usage:
//   train_net net_proto_file solver_proto_file [resume_point_file]
//
// Returns 0 after optimization finishes. Returns 1 on a wrong argument count
// or if CUDA device 0 cannot be selected.
int main(int argc, char** argv) {
  // Two arguments are required and a third is optional. Validate before
  // touching argv[1..3] so a missing argument does not dereference past argv.
  // Extra arguments are rejected too, because otherwise a resume file would be
  // silently ignored.
  if (argc < 3 || argc > 4) {
    const char* program = (argc > 0 && argv[0] != NULL) ? argv[0] : "train_net";
    LOG(ERROR) << "Usage: " << program
               << " net_proto_file solver_proto_file [resume_point_file]";
    return 1;
  }

  // The device is hard-coded to GPU 0. Check the status here so that a
  // missing or unavailable GPU is reported at its source rather than by a
  // later CUDA call.
  cudaError_t device_status = cudaSetDevice(0);
  if (device_status != cudaSuccess) {
    LOG(ERROR) << "cudaSetDevice(0) failed: "
               << cudaGetErrorString(device_status);
    return 1;
  }
  Caffe::set_mode(Caffe::GPU);
  Caffe::set_phase(Caffe::TRAIN);

  // Build the network from its text-format description. The bottom vector is
  // empty, so no external input blobs are handed to the net here.
  NetParameter net_param;
  ReadProtoFromTextFile(argv[1], &net_param);
  vector<Blob<float>*> bottom_vec;
  Net<float> caffe_net(net_param, bottom_vec);

  SolverParameter solver_param;
  ReadProtoFromTextFile(argv[2], &solver_param);

  LOG(ERROR) << "Starting Optimization";
  SGDSolver<float> solver(solver_param);
  if (argc == 4) {
    // Optional third argument: resume optimization from this file.
    LOG(ERROR) << "Resuming from " << argv[3];
    solver.Solve(&caffe_net, argv[3]);
  } else {
    solver.Solve(&caffe_net);
  }
  LOG(ERROR) << "Optimization Done.";

  return 0;
}