1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_COMMON_HPP_
4 #define CAFFE_COMMON_HPP_
6 #include <boost/shared_ptr.hpp>
7 #include <cublas_v2.h>
8 #include <cuda.h>
9 #include <curand.h>
10 // cuda driver types
11 #include <driver_types.h>
12 #include <glog/logging.h>
13 #include <mkl_vsl.h>
// Status-checking helpers for the GPU/MKL libraries used by Caffe.
// Each macro evaluates `condition` once and hands it to glog's CHECK_EQ,
// comparing it against that library's success status; a mismatch is a fatal
// glog failure. Because they expand to a plain CHECK_EQ expression, callers
// can stream extra context after them, e.g.
//   CUDA_CHECK(cudaMalloc(&ptr, size)) << "allocating " << size << " bytes";

// MKL Vector Statistics Library calls.
#define VSL_CHECK(condition) \
  CHECK_EQ((condition), VSL_STATUS_OK)

// cuRAND calls.
#define CURAND_CHECK(condition) \
  CHECK_EQ((condition), CURAND_STATUS_SUCCESS)

// cuBLAS calls.
#define CUBLAS_CHECK(condition) \
  CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)

// CUDA runtime API calls.
#define CUDA_CHECK(condition) \
  CHECK_EQ((condition), cudaSuccess)
// Call immediately after a kernel launch, written as `CUDA_POST_KERNEL_CHECK;`.
// Reads the most recent CUDA runtime error with cudaPeekAtLastError(), which
// does not clear it, and if it is not cudaSuccess logs it at FATAL severity
// (aborting). The human-readable cudaGetErrorString() text is printed along
// with the numeric code; the original printed only the code.
// Errors raised while the kernel is still running asynchronously may only be
// reported by a later synchronizing call, so this mainly catches launch
// failures such as an invalid configuration.
// The do { } while (0) wrapper makes the macro a single statement, so it is
// safe inside an unbraced if/else.
#define CUDA_POST_KERNEL_CHECK \
  do { \
    const cudaError_t caffe_post_kernel_err_ = cudaPeekAtLastError(); \
    if (caffe_post_kernel_err_ != cudaSuccess) { \
      LOG(FATAL) << "Cuda kernel failed. Error: " \
                 << cudaGetErrorString(caffe_post_kernel_err_) \
                 << " (" << caffe_post_kernel_err_ << ")"; \
    } \
  } while (0)
// Makes `classname` non-copyable. It declares, but never defines, a private
// copy-assignment operator and a private copy constructor. Copying from
// outside the class is rejected at compile time; copying from inside it fails
// at link time.
// NOTE: the macro switches access to `private:` and leaves it there, so any
// members declared after it in the class body are private. The caller
// supplies the trailing semicolon.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
  classname& operator=(const classname&); \
  classname(const classname&)
// Emits explicit instantiations of the class template `classname` for the two
// floating-point types Caffe supports, float and double. It expands to
// explicit-instantiation declarations, so it must be used at namespace scope.
// The caller supplies the trailing semicolon.
#define INSTANTIATE_CLASS(classname) \
  template class classname<float>; \
  template class classname<double>

// Placeholder for code paths that are not written yet. It logs the message
// through glog at FATAL severity, which aborts the program.
#define NOT_IMPLEMENTED \
  LOG(FATAL) << "Not Implemented Yet"
37 namespace caffe {
39 // We will use the boost shared_ptr instead of the new C++11 one mainly
40 // because cuda does not work (at least now) well with C++11 features.
41 using boost::shared_ptr;
43 // For backward compatibility we will just use 512 threads per block
44 const int CAFFE_CUDA_NUM_THREADS = 512;
46 inline int CAFFE_GET_BLOCKS(const int N) {
47 return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
48 }
50 // A singleton class to hold common caffe stuff, such as the handler that
51 // caffe is going to use for cublas.
52 class Caffe {
53 public:
54 ~Caffe();
55 static Caffe& Get();
56 enum Brew { CPU, GPU };
57 enum Phase { TRAIN, TEST };
59 // The getters for the variables.
60 // Returns the cublas handle.
61 static cublasHandle_t cublas_handle();
62 // Returns the curand generator.
63 static curandGenerator_t curand_generator();
64 // Returns the MKL random stream.
65 static VSLStreamStatePtr vsl_stream();
66 // Returns the mode: running on CPU or GPU.
67 static Brew mode();
68 // Returns the phase: TRAIN or TEST.
69 static Phase phase();
70 // The setters for the variables
71 // Sets the mode.
72 static void set_mode(Brew mode);
73 // Sets the phase.
74 static void set_phase(Phase phase);
75 // Sets the random seed of both MKL and curand
76 static void set_random_seed(const unsigned int seed);
78 private:
79 // The private constructor to avoid duplicate instantiation.
80 Caffe();
82 protected:
83 static shared_ptr<Caffe> singleton_;
84 cublasHandle_t cublas_handle_;
85 curandGenerator_t curand_generator_;
86 VSLStreamStatePtr vsl_stream_;
87 Brew mode_;
88 Phase phase_;
89 };
92 } // namespace caffe
94 #endif // CAFFE_COMMON_HPP_