1 #ifndef CAFFE_COMMON_HPP_
2 #define CAFFE_COMMON_HPP_
4 #include <boost/shared_ptr.hpp>
5 #include <cublas_v2.h>
6 #include <cuda.h>
7 #include <curand.h>
8 #include <glog/logging.h>
9 #include <mkl_vsl.h>
11 #include "driver_types.h"
// Status-checking wrappers for the CUDA runtime, cuBLAS, cuRAND and MKL VSL.
// Each one evaluates `condition` exactly once and fails a (fatal) glog
// CHECK_EQ when the returned status is not the library's success code.
//
// CUDA_CHECK also appends the text from cudaGetErrorString, because the raw
// cudaError_t value is hard to interpret on its own. It is wrapped in
// do { ... } while (0) so that `CUDA_CHECK(call);` is a single statement,
// which makes it safe inside an unbraced if/else. The local has a deliberately
// unusual name so it cannot shadow a caller's variable used inside
// `condition`.
#define CUDA_CHECK(condition) \
  do { \
    cudaError_t caffe_cuda_check_status_ = (condition); \
    CHECK_EQ(caffe_cuda_check_status_, cudaSuccess) \
        << " " << cudaGetErrorString(caffe_cuda_check_status_); \
  } while (0)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
// Use right after a kernel launch as `CUDA_POST_KERNEL_CHECK;`.
//
// It reads the pending runtime error with cudaPeekAtLastError, which does not
// clear it. If the error is not cudaSuccess, it aborts through LOG(FATAL) with
// a readable description.
//
// Kernel launches are asynchronous, and this macro does not synchronize. It
// therefore catches launch/configuration errors. Faults during the kernel's
// execution may only surface at a later synchronizing call.
//
// The status is captured once, so the reported error is the one that was
// tested. The do/while(0) wrapper makes the macro a single statement, which
// keeps an unbraced if/else around it well-formed.
#define CUDA_POST_KERNEL_CHECK \
  do { \
    cudaError_t caffe_post_kernel_status_ = cudaPeekAtLastError(); \
    if (cudaSuccess != caffe_post_kernel_status_) { \
      LOG(FATAL) << "Cuda kernel failed. Error: " \
                 << cudaGetErrorString(caffe_post_kernel_status_); \
    } \
  } while (0)
// Emits explicit instantiations of the class template `classname` for float
// and for double. The expansion has no trailing ';', so the call site
// supplies it: INSTANTIATE_CLASS(Foo);
#define INSTANTIATE_CLASS(classname) \
  template class classname<float>; \
  template class classname<double>

// Marks a code path that has not been written yet by emitting a fatal glog
// message. It expands to a log stream, so more context can be appended
// with <<.
#define NOT_IMPLEMENTED \
  LOG(FATAL) << "Not Implemented Yet"
29 namespace caffe {
31 // We will use the boost shared_ptr instead of the new C++11 one mainly
32 // because cuda does not work (at least now) well with C++11 features.
33 using boost::shared_ptr;
// Threads per block for Caffe's CUDA kernel launches. Kept at 512 for
// backward compatibility.
const int CAFFE_CUDA_NUM_THREADS = 512;

// Returns the number of CAFFE_CUDA_NUM_THREADS-sized blocks needed to cover
// N elements, i.e. ceil(N / CAFFE_CUDA_NUM_THREADS) for N > 0.
//
// The positive case uses (N - 1) / T + 1 instead of (N + T - 1) / T. The
// latter overflows int (undefined behavior) when N is within T - 1 of
// INT_MAX. Non-positive N keeps the original formula, so the result is
// unchanged there (0 for N == 0).
// NOTE(review): a zero-block grid is not a valid launch configuration;
// presumably callers skip the launch when N == 0 — confirm.
inline int CAFFE_GET_BLOCKS(const int N) {
  if (N > 0) {
    return (N - 1) / CAFFE_CUDA_NUM_THREADS + 1;
  }
  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
42 // A singleton class to hold common caffe stuff, such as the handler that
43 // caffe is going to use for cublas.
44 class Caffe {
45 public:
46 ~Caffe();
47 static Caffe& Get();
48 enum Brew { CPU, GPU };
49 enum Phase { TRAIN, TEST};
51 // The getters for the variables.
52 static cublasHandle_t cublas_handle();
53 static curandGenerator_t curand_generator();
54 static VSLStreamStatePtr vsl_stream();
55 static Brew mode();
56 static Phase phase();
57 // The setters for the variables
58 static void set_mode(Brew mode);
59 static void set_phase(Phase phase);
60 static void set_random_seed(const unsigned int seed);
61 private:
62 Caffe();
63 static shared_ptr<Caffe> singleton_;
64 cublasHandle_t cublas_handle_;
65 curandGenerator_t curand_generator_;
66 VSLStreamStatePtr vsl_stream_;
67 Brew mode_;
68 Phase phase_;
69 };
71 } // namespace caffe
73 #endif // CAFFE_COMMON_HPP_