diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index c254d70e4371fbb3bbb45d3efa1cc42ad3516762..a70a28087d391a1c750303ce4d5a98afd3662d47 100644 (file)
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
}
Caffe::~Caffe() {
- if (!cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
- if (!curand_generator_) {
+ if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
+ if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
- if (!vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
+ if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
};
void Caffe::set_random_seed(const unsigned int seed) {
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
}
+void Caffe::SetDevice(const int device_id) {
+ int current_device;
+ CUDA_CHECK(cudaGetDevice(¤t_device));
+ if (current_device == device_id) {
+ return;
+ }
+ if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
+ if (Get().curand_generator_) {
+ CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
+ }
+ CUDA_CHECK(cudaSetDevice(device_id));
+ CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
+ CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+ CURAND_RNG_PSEUDO_DEFAULT));
+ CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
+ time(NULL)));
+}
+
void Caffe::DeviceQuery() {
cudaDeviceProp prop;
int device;
- CUDA_CHECK(cudaGetDevice(&device));
+ if (cudaSuccess != cudaGetDevice(&device)) {
+ printf("No cuda device present.\n");
+ return;
+ }
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
printf("Major revision number: %d\n", prop.major);
printf("Minor revision number: %d\n", prop.minor);