]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/common.hpp
pylint and code cleaning
[jacinto-ai/caffe-jacinto.git] / src / caffe / common.hpp
1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_COMMON_HPP_
4 #define CAFFE_COMMON_HPP_
6 #include <boost/shared_ptr.hpp>
7 #include <cublas_v2.h>
8 #include <cuda.h>
9 #include <curand.h>
10 //cuda driver types
11 #include <driver_types.h>
12 #include <glog/logging.h>
13 #include <mkl_vsl.h>
15 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
16 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
17 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
18 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
20 #define CUDA_POST_KERNEL_CHECK \
21   if (cudaSuccess != cudaPeekAtLastError()) {\
22     LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
23   }
25 #define INSTANTIATE_CLASS(classname) \
26   template class classname<float>; \
27   template class classname<double>
29 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
32 namespace caffe {
34 // We will use the boost shared_ptr instead of the new C++11 one mainly
35 // because cuda does not work (at least now) well with C++11 features.
36 using boost::shared_ptr;
38 // For backward compatibility we will just use 512 threads per block
39 const int CAFFE_CUDA_NUM_THREADS = 512;
41 inline int CAFFE_GET_BLOCKS(const int N) {
42   return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
43 }
45 // A singleton class to hold common caffe stuff, such as the handler that
46 // caffe is going to use for cublas.
47 class Caffe {
48  public:
49   ~Caffe();
50   static Caffe& Get();
51   enum Brew { CPU, GPU };
52   enum Phase { TRAIN, TEST };
54   // The getters for the variables.
55   // Returns the cublas handle.
56   static cublasHandle_t cublas_handle();
57   // Returns the curand generator.
58   static curandGenerator_t curand_generator();
59   // Returns the MKL random stream.
60   static VSLStreamStatePtr vsl_stream();
61   // Returns the mode: running on CPU or GPU.
62   static Brew mode();
63   // Returns the phase: TRAIN or TEST.
64   static Phase phase();
65   // The setters for the variables
66   // Sets the mode.
67   static void set_mode(Brew mode);
68   // Sets the phase.
69   static void set_phase(Phase phase);
70   // Sets the random seed of both MKL and curand
71   static void set_random_seed(const unsigned int seed);
73  private:
74   // The private constructor to avoid duplicate instantiation.
75   Caffe();
77  protected:
78   static shared_ptr<Caffe> singleton_;
79   cublasHandle_t cublas_handle_;
80   curandGenerator_t curand_generator_;
81   VSLStreamStatePtr vsl_stream_;
82   Brew mode_;
83   Phase phase_;
84 };
87 }  // namespace caffe
89 #endif  // CAFFE_COMMON_HPP_