3abb1c3955557486d54534666ea8be2e0b78861d
1 // Copyright 2013 Yangqing Jia
2 //
3 // This is a simple script that allows one to quickly train a network whose
4 // parameters are specified by text format protocol buffers.
5 // Usage:
6 // train_net net_proto_file solver_proto_file [resume_point_file]
8 #include <cuda_runtime.h>
10 #include <cstring>
12 #include "caffe/caffe.hpp"
14 using namespace caffe;
// Entry point of train_net: builds the network described by the text-format
// NetParameter in argv[1], then runs SGDSolver with the text-format
// SolverParameter in argv[2]. If argv[3] is given, it is passed to
// SGDSolver::Solve as the resume point file.
//
// Returns 0 after optimization finishes, and 1 on bad usage or when CUDA
// device 0 cannot be selected.
int main(int argc, char** argv) {
  // Accept exactly two or three positional arguments. Previously extra
  // arguments were silently ignored, which also silently disabled resuming
  // (the resume branch only fired for argc == 4).
  if (argc < 3 || argc > 4) {
    LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
        << "[resume_point_file]";
    return 1;  // non-zero so callers/scripts can detect misuse
  }

  // Training is pinned to GPU 0. Check the selection so a bad device setup
  // fails here with a clear message rather than at a later CUDA call.
  const cudaError_t device_status = cudaSetDevice(0);
  if (device_status != cudaSuccess) {
    LOG(ERROR) << "cudaSetDevice(0) failed: "
        << cudaGetErrorString(device_status);
    return 1;
  }
  Caffe::set_mode(Caffe::GPU);
  Caffe::set_phase(Caffe::TRAIN);

  // Build the network. An empty bottom_vec is passed to the Net constructor
  // (presumably: no externally supplied input blobs — confirm against Net).
  NetParameter net_param;
  ReadProtoFromTextFile(argv[1], &net_param);
  vector<Blob<float>*> bottom_vec;
  Net<float> caffe_net(net_param, bottom_vec);

  SolverParameter solver_param;
  ReadProtoFromTextFile(argv[2], &solver_param);

  LOG(ERROR) << "Starting Optimization";
  SGDSolver<float> solver(solver_param);
  if (argc == 4) {
    // Optional third argument: resume point file handed to the solver.
    LOG(ERROR) << "Resuming from " << argv[3];
    solver.Solve(&caffe_net, argv[3]);
  } else {
    solver.Solve(&caffe_net);
  }
  LOG(ERROR) << "Optimization Done.";

  return 0;
}