edc6c1f39bbb546414e85d52ef7f697e337d9a87
1 #include <cstring>
2 #include <cuda_runtime.h>
4 #include "gtest/gtest.h"
5 #include "caffeine/common.hpp"
6 #include "caffeine/syncedmem.hpp"
8 namespace caffeine {
10 class CommonTest : public ::testing::Test {};
12 TEST_F(CommonTest, TestCublasHandler) {
13 int cuda_device_id;
14 CUDA_CHECK(cudaGetDevice(&cuda_device_id));
15 EXPECT_TRUE(Caffeine::cublas_handle());
16 }
18 TEST_F(CommonTest, TestVslStream) {
19 EXPECT_TRUE(Caffeine::vsl_stream());
20 }
22 TEST_F(CommonTest, TestBrewMode) {
23 EXPECT_EQ(Caffeine::mode(), Caffeine::CPU);
24 Caffeine::set_mode(Caffeine::GPU);
25 EXPECT_EQ(Caffeine::mode(), Caffeine::GPU);
26 }
28 TEST_F(CommonTest, TestPhase) {
29 EXPECT_EQ(Caffeine::phase(), Caffeine::TRAIN);
30 Caffeine::set_phase(Caffeine::TEST);
31 EXPECT_EQ(Caffeine::phase(), Caffeine::TEST);
32 }
34 TEST_F(CommonTest, TestRandSeedCPU) {
35 SyncedMemory data_a(10 * sizeof(int));
36 SyncedMemory data_b(10 * sizeof(int));
37 Caffeine::set_random_seed(1701);
38 viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
39 10, (int*)data_a.mutable_cpu_data(), 0.5);
40 Caffeine::set_random_seed(1701);
41 viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
42 10, (int*)data_b.mutable_cpu_data(), 0.5);
43 for (int i = 0; i < 10; ++i) {
44 EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
45 ((const int*)(data_b.cpu_data()))[i]);
46 }
47 }
50 TEST_F(CommonTest, TestRandSeedGPU) {
51 SyncedMemory data_a(10 * sizeof(unsigned int));
52 SyncedMemory data_b(10 * sizeof(unsigned int));
53 Caffeine::set_random_seed(1701);
54 CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
55 (unsigned int*)data_a.mutable_gpu_data(), 10));
56 Caffeine::set_random_seed(1701);
57 CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
58 (unsigned int*)data_b.mutable_gpu_data(), 10));
59 for (int i = 0; i < 10; ++i) {
60 EXPECT_EQ(((const unsigned int*)(data_a.cpu_data()))[i],
61 ((const unsigned int*)(data_b.cpu_data()))[i]);
62 }
63 }
66 } // namespace caffeine