solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / src / caffe / common.cpp
index aecdc6e12666345606eb3e011cf6547f639ec385..a70a28087d391a1c750303ce4d5a98afd3662d47 100644 (file)
@@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
   VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
 }
 
+void Caffe::SetDevice(const int device_id) {
+  int current_device;
+  CUDA_CHECK(cudaGetDevice(&current_device));
+  if (current_device == device_id) {
+    return;
+  }
+  if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
+  if (Get().curand_generator_) {
+    CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
+  }
+  CUDA_CHECK(cudaSetDevice(device_id));
+  CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
+  CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+      CURAND_RNG_PSEUDO_DEFAULT));
+  CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
+      time(NULL)));
+}
+
 void Caffe::DeviceQuery() {
   cudaDeviceProp prop;
   int device;