]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - include/caffe/common.hpp
conv_layer bugfix
[jacinto-ai/caffe-jacinto.git] / include / caffe / common.hpp
1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_COMMON_HPP_
4 #define CAFFE_COMMON_HPP_
6 #include <boost/shared_ptr.hpp>
7 #include <cublas_v2.h>
8 #include <cuda.h>
9 #include <curand.h>
10 // cuda driver types
11 #include <driver_types.h>
12 #include <glog/logging.h>
13 #include <mkl_vsl.h>
15 // various checks for different function calls.
16 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
17 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
18 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
19 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
21 // After a kernel is executed, this will check the error and if there is one,
22 // exit loudly.
23 #define CUDA_POST_KERNEL_CHECK \
24   if (cudaSuccess != cudaPeekAtLastError()) \
25     LOG(FATAL) << "Cuda kernel failed. Error: " \
26         << cudaGetErrorString(cudaPeekAtLastError())
28 // Disable the copy and assignment operator for a class.
29 #define DISABLE_COPY_AND_ASSIGN(classname) \
30 private:\
31   classname(const classname&);\
32   classname& operator=(const classname&)
34 // Instantiate a class with float and double specifications.
35 #define INSTANTIATE_CLASS(classname) \
36   template class classname<float>; \
37   template class classname<double>
39 // A simple macro to mark codes that are not implemented, so that when the code
40 // is executed we will see a fatal log.
41 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
44 namespace caffe {
46 // We will use the boost shared_ptr instead of the new C++11 one mainly
47 // because cuda does not work (at least now) well with C++11 features.
48 using boost::shared_ptr;
51 // We will use 1024 threads per block, which requires cuda sm_2x or above.
52 const int CAFFE_CUDA_NUM_THREADS = 1024;
55 inline int CAFFE_GET_BLOCKS(const int N) {
56   return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
57 }
60 // A singleton class to hold common caffe stuff, such as the handler that
61 // caffe is going to use for cublas, curand, etc.
62 class Caffe {
63  public:
64   ~Caffe();
65   inline static Caffe& Get() {
66     if (!singleton_.get()) {
67       singleton_.reset(new Caffe());
68     }
69     return *singleton_;
70   }
71   enum Brew { CPU, GPU };
72   enum Phase { TRAIN, TEST };
74   // The getters for the variables.
75   // Returns the cublas handle.
76   inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
77   // Returns the curand generator.
78   inline static curandGenerator_t curand_generator() {
79     return Get().curand_generator_;
80   }
81   // Returns the MKL random stream.
82   inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
83   // Returns the mode: running on CPU or GPU.
84   inline static Brew mode() { return Get().mode_; }
85   // Returns the phase: TRAIN or TEST.
86   inline static Phase phase() { return Get().phase_; }
87   // The setters for the variables
88   // Sets the mode.
89   inline static void set_mode(Brew mode) { Get().mode_ = mode; }
90   // Sets the phase.
91   inline static void set_phase(Phase phase) { Get().phase_ = phase; }
92   // Sets the random seed of both MKL and curand
93   static void set_random_seed(const unsigned int seed);
95  protected:
96   cublasHandle_t cublas_handle_;
97   curandGenerator_t curand_generator_;
98   VSLStreamStatePtr vsl_stream_;
99   Brew mode_;
100   Phase phase_;
101   static shared_ptr<Caffe> singleton_;
103  private:
104   // The private constructor to avoid duplicate instantiation.
105   Caffe();
107   DISABLE_COPY_AND_ASSIGN(Caffe);
108 };
110 }  // namespace caffe
112 #endif  // CAFFE_COMMON_HPP_