7362b233780d7d205906df1e53bbab2950e86fa5
1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_COMMON_HPP_
4 #define CAFFE_COMMON_HPP_
6 #include <boost/shared_ptr.hpp>
7 #include <cublas_v2.h>
8 #include <cuda.h>
9 #include <curand.h>
10 // cuda driver types
11 #include <driver_types.h>
12 #include <glog/logging.h>
13 #include <mkl_vsl.h>
// Status-checking wrappers for calls into the CUDA runtime, cuBLAS, cuRAND
// and MKL VSL. Each wraps the given call in a glog CHECK_EQ against that
// library's success code, so a non-success status fails the check.
// CUDA runtime calls (cudaError_t).
#define CUDA_CHECK(status) CHECK_EQ((status), cudaSuccess)
// cuBLAS calls (cublasStatus_t).
#define CUBLAS_CHECK(status) CHECK_EQ((status), CUBLAS_STATUS_SUCCESS)
// cuRAND calls (curandStatus_t).
#define CURAND_CHECK(status) CHECK_EQ((status), CURAND_STATUS_SUCCESS)
// MKL VSL calls.
#define VSL_CHECK(status) CHECK_EQ((status), VSL_STATUS_OK)
// Use right after a kernel launch. It inspects the last error recorded by the
// CUDA runtime and, if there is one, exits loudly via LOG(FATAL) with the
// error string. cudaPeekAtLastError() is used, so the recorded error is not
// cleared.
//
// Written with glog's LOG_IF instead of a bare `if` so that the whole macro is
// a single expression. A bare `if` would capture the caller's `else` when the
// macro is used as the body of an unbraced if/else. Extra context can still
// be streamed onto it, e.g. `CUDA_POST_KERNEL_CHECK << " in my_kernel";`.
#define CUDA_POST_KERNEL_CHECK \
  LOG_IF(FATAL, cudaSuccess != cudaPeekAtLastError()) \
      << "Cuda kernel failed. Error: " \
      << cudaGetErrorString(cudaPeekAtLastError())
// Declares, without defining, a private copy constructor and a private
// copy-assignment operator for `class_name`. As a result, code outside the
// class cannot copy or copy-assign instances.
// NOTE: the expansion leaves the access specifier set to `private`, so any
// members declared after it are private unless another specifier follows.
#define DISABLE_COPY_AND_ASSIGN(class_name) \
 private: \
  class_name(const class_name&); \
  class_name& operator=(const class_name&)

// Explicitly instantiates the class template `class_name` for both `float`
// and `double`.
#define INSTANTIATE_CLASS(class_name) \
  template class class_name<float>; \
  template class class_name<double>

// Marks code paths that have not been implemented: reaching one emits a
// LOG(FATAL) message.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
44 namespace caffe {
46 // We will use the boost shared_ptr instead of the new C++11 one mainly
47 // because cuda does not work (at least now) well with C++11 features.
48 using boost::shared_ptr;
51 // We will use 1024 threads per block, which requires cuda sm_2x or above.
52 const int CAFFE_CUDA_NUM_THREADS = 1024;
55 inline int CAFFE_GET_BLOCKS(const int N) {
56 return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
57 }
60 // A singleton class to hold common caffe stuff, such as the handler that
61 // caffe is going to use for cublas, curand, etc.
62 class Caffe {
63 public:
64 ~Caffe();
65 inline static Caffe& Get() {
66 if (!singleton_.get()) {
67 singleton_.reset(new Caffe());
68 }
69 return *singleton_;
70 }
71 enum Brew { CPU, GPU };
72 enum Phase { TRAIN, TEST };
74 // The getters for the variables.
75 // Returns the cublas handle.
76 inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
77 // Returns the curand generator.
78 inline static curandGenerator_t curand_generator() {
79 return Get().curand_generator_;
80 }
81 // Returns the MKL random stream.
82 inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
83 // Returns the mode: running on CPU or GPU.
84 inline static Brew mode() { return Get().mode_; }
85 // Returns the phase: TRAIN or TEST.
86 inline static Phase phase() { return Get().phase_; }
87 // The setters for the variables
88 // Sets the mode.
89 inline static void set_mode(Brew mode) { Get().mode_ = mode; }
90 // Sets the phase.
91 inline static void set_phase(Phase phase) { Get().phase_ = phase; }
92 // Sets the random seed of both MKL and curand
93 static void set_random_seed(const unsigned int seed);
94 // Prints the current GPU status.
95 static void DeviceQuery();
97 protected:
98 cublasHandle_t cublas_handle_;
99 curandGenerator_t curand_generator_;
100 VSLStreamStatePtr vsl_stream_;
101 Brew mode_;
102 Phase phase_;
103 static shared_ptr<Caffe> singleton_;
105 private:
106 // The private constructor to avoid duplicate instantiation.
107 Caffe();
109 DISABLE_COPY_AND_ASSIGN(Caffe);
110 };
112 } // namespace caffe
114 #endif // CAFFE_COMMON_HPP_