]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
BatchReindexLayer fix
authorSergei Nikolaev <snikolaev@nvidia.com>
Fri, 19 Jan 2018 03:56:44 +0000 (19:56 -0800)
committerSergei Nikolaev <snikolaev@nvidia.com>
Fri, 19 Jan 2018 03:56:44 +0000 (19:56 -0800)
python/caffe/io.py
src/caffe/util/gpu_amax.cu
src/caffe/util/gpu_asum.cu
src/caffe/util/gpu_sumsq.cu

index cee5ace2e885054291a3174ed0e936fac4db42d3..5f933876cc058fb29b6891b257abf1e4eb683b69 100644 (file)
@@ -256,7 +256,14 @@ class Transformer:
             if len(ms) != 3:
                 raise ValueError('Mean shape invalid')
             if ms != self.inputs[in_][1:]:
-                raise ValueError('Mean shape incompatible with input shape.')
+                print(self.inputs[in_])
+                in_shape = self.inputs[in_][1:]
+                m_min, m_max = mean.min(), mean.max()
+                normal_mean = (mean - m_min) / (m_max - m_min)
+                mean = resize_image(normal_mean.transpose((1,2,0)),
+                                    in_shape[1:]).transpose((2,0,1)) * \
+                                    (m_max - m_min) + m_min
+                # raise ValueError('Mean shape incompatible with input shape.')
         self.mean[in_] = mean
 
     def set_input_scale(self, in_, scale):
index 577069aa7c4a6c4bc8c4b4c4077b31d5dac8e889..079815056afe7ce9eeba25a8e016bedd33069a6b 100644 (file)
@@ -12,38 +12,40 @@ namespace caffe {
 SHMEM(amax);
 CAFFE_GPU_SHMEM(amax);
 
+#define BLOCK_REDUCE_AMAX(TNUM) \
+if (BlockSize >= (TNUM) * 2) { \
+  if (tid < (TNUM)) { \
+    tmax_replace(st, sdata[tid + (TNUM)]); \
+  } \
+  __syncthreads(); \
+}
+
+#define REDUCE_AMAX(TNUM) \
+if (tid + (TNUM) < thread_count) { \
+  tmax_replace(st, sdata[tid + (TNUM)]); \
+  __syncthreads(); \
+}
+
 ///////////////////////////////////// AMAX REDUCTION ///////////////////////////////////
 
 template<unsigned int BlockSize, typename T>
 __device__ void amax_reduce_block(volatile T *sdata, T my_max, unsigned int tid) {
+  const int thread_count = blockDim.x * blockDim.y * blockDim.z;
   volatile T* st = sdata + tid;
-
   tassign(st, my_max);
   __syncthreads();
 
   // do reduction in shared mem
-  if (BlockSize >= 512) {
-    if (tid < 256) {
-      tmax_replace(st, sdata[tid + 256]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 256) {
-    if (tid < 128) {
-      tmax_replace(st, sdata[tid + 128]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 128) {
-    if (tid < 64) {
-      tmax_replace(st, sdata[tid + 64]);
-    }
-    __syncthreads();
-  }
+  BLOCK_REDUCE_AMAX(256)
+  BLOCK_REDUCE_AMAX(128)
+  BLOCK_REDUCE_AMAX(64)
   if (tid < 32) {
-    for (int i = 32; i > 0; i >>= 1) {
-      tmax_replace(st, sdata[tid + i]);
-    }
+    REDUCE_AMAX(32)
+    REDUCE_AMAX(16)
+    REDUCE_AMAX(8)
+    REDUCE_AMAX(4)
+    REDUCE_AMAX(2)
+    REDUCE_AMAX(1)
   }
 }
 
@@ -122,7 +124,7 @@ __global__ void amax_reduce_kernel(unsigned int n, const T *in, TR *out, int gro
 template <typename T, typename TR>
 void gpu_amax_t(const int n, const T* x, TR* result, int group) {
   CHECK_LT(group, REGRESSION_GROUPS_MAX);
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   const bool po2 = is_pow2(n);
   // See kernel for details
   CHECK_LE(CAFFE_CUDA_NUM_THREADS_HALF, 512);
@@ -159,7 +161,7 @@ template<>
 void caffe_gpu_amax<float16>(const int n, const float16* x, float* y, int group) {
   // For odd counts we allocate extra element to speed up kernels.
   // We have to keep it clean.
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   if (n & 1) {
     clean_last_element(const_cast<float16*>(x) + n, stream);
   }
index 11423a414966252b35d323f6021a98bd193ab848..cfe665a9ea3980d20e50a21ccf380bee3eadb284 100644 (file)
@@ -10,37 +10,39 @@ namespace caffe {
 SHMEM(asum);
 CAFFE_GPU_SHMEM(asum);
 
+#define BLOCK_REDUCE_ASUM(TNUM) \
+if (BlockSize >= (TNUM) * 2) { \
+  if (tid < (TNUM)) { \
+    tsum_replace(st, sdata[tid + (TNUM)]); \
+  } \
+  __syncthreads(); \
+}
+
+#define REDUCE_ASUM(TNUM) \
+if (tid + (TNUM) < thread_count) { \
+  tsum_replace(st, sdata[tid + (TNUM)]); \
+  __syncthreads(); \
+}
+
 ///////////////////////////////////// ASUM REDUCTION ///////////////////////////////////
 
 template<unsigned int BlockSize, typename TR>
 __device__ void asum_reduce_block(volatile TR *sdata, TR my_sum, unsigned int tid) {
+  const int thread_count = blockDim.x * blockDim.y * blockDim.z;
   volatile TR* st = sdata + tid;
   tassign(st, my_sum);
   __syncthreads();
-
   // do reduction in shared mem
-  if (BlockSize >= 512) {
-    if (tid < 256) {
-      tsum_replace(st, sdata[tid + 256]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 256) {
-    if (tid < 128) {
-      tsum_replace(st, sdata[tid + 128]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 128) {
-    if (tid < 64) {
-      tsum_replace(st, sdata[tid + 64]);
-    }
-    __syncthreads();
-  }
+  BLOCK_REDUCE_ASUM(256)
+  BLOCK_REDUCE_ASUM(128)
+  BLOCK_REDUCE_ASUM(64)
   if (tid < 32) {
-    for (int i = 32; i > 0; i >>= 1) {
-      tsum_replace(st, sdata[tid + i]);
-    }
+    REDUCE_ASUM(32)
+    REDUCE_ASUM(16)
+    REDUCE_ASUM(8)
+    REDUCE_ASUM(4)
+    REDUCE_ASUM(2)
+    REDUCE_ASUM(1)
   }
 }
 
@@ -124,7 +126,7 @@ __global__ void asum_reduce_kernel(unsigned int n, const T *in, TR *out, int gro
 template<typename T, typename TR>
 void gpu_asum_t(const int n, const T* x, TR* sum, int group) {
   CHECK_LT(group, REGRESSION_GROUPS_MAX);
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   const bool po2 = is_pow2(n);
   // See kernel for details
   CHECK_LE(CAFFE_CUDA_NUM_THREADS_HALF, 512);
@@ -155,7 +157,7 @@ template<>
 void caffe_gpu_asum<float16, float>(const int n, const float16* x, float* sum, int group) {
   // For odd counts we allocate extra element to speed up kernels.
   // We have to keep it clean.
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   if (n & 1) {
     clean_last_element(const_cast<float16*>(x) + n, stream);
   }
index 8282303f1e12d58ad55bb1880a82afc9cbcbd222..11cd1b3e8219cf5b4a50318d632bd39863d59da7 100644 (file)
@@ -10,37 +10,40 @@ namespace caffe {
 SHMEM(sq);
 CAFFE_GPU_SHMEM(sq);
 
+#define BLOCK_REDUCE_ASUM(TNUM) \
+if (BlockSize >= (TNUM) * 2) { \
+  if (tid < (TNUM)) { \
+    tsum_replace(st, sdata[tid + (TNUM)]); \
+  } \
+  __syncthreads(); \
+}
+
+#define REDUCE_ASUM(TNUM) \
+if (tid + (TNUM) < thread_count) { \
+  tsum_replace(st, sdata[tid + (TNUM)]); \
+  __syncthreads(); \
+}
+
+
 ///////////////////////////////////// SUMSQ REDUCTION ///////////////////////////////////
 
 template<unsigned int BlockSize, typename TR>
 __device__ void sumsq_reduce_block(volatile TR *sdata, TR my_sum, unsigned int tid) {
+  const int thread_count = blockDim.x * blockDim.y * blockDim.z;
   volatile TR* st = sdata + tid;
   tassign(st, my_sum);
   __syncthreads();
-
   // do reduction in shared mem
-  if (BlockSize >= 512) {
-    if (tid < 256) {
-      tsum_replace(st, sdata[tid + 256]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 256) {
-    if (tid < 128) {
-      tsum_replace(st, sdata[tid + 128]);
-    }
-    __syncthreads();
-  }
-  if (BlockSize >= 128) {
-    if (tid < 64) {
-      tsum_replace(st, sdata[tid + 64]);
-    }
-    __syncthreads();
-  }
+  BLOCK_REDUCE_ASUM(256)
+  BLOCK_REDUCE_ASUM(128)
+  BLOCK_REDUCE_ASUM(64)
   if (tid < 32) {
-    for (int i = 32; i > 0; i >>= 1) {
-      tsum_replace(st, sdata[tid + i]);
-    }
+    REDUCE_ASUM(32)
+    REDUCE_ASUM(16)
+    REDUCE_ASUM(8)
+    REDUCE_ASUM(4)
+    REDUCE_ASUM(2)
+    REDUCE_ASUM(1)
   }
 }
 
@@ -124,7 +127,7 @@ __global__ void sumsq_reduce_kernel(unsigned int n, const T *in, TR *out, int gr
 template<typename T, typename TR>
 void gpu_sumsq_t(const int n, const T* x, TR* sum, int group) {
   CHECK_LT(group, REGRESSION_GROUPS_MAX);
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   const bool po2 = is_pow2(n);
   // See kernel for details
   CHECK_LE(CAFFE_CUDA_NUM_THREADS_HALF, 512);
@@ -155,7 +158,7 @@ template<>
 void caffe_gpu_sumsq<float16, float>(const int n, const float16* x, float* sum, int group) {
   // For odd counts we allocate extra element to speed up kernels.
   // We have to keep it clean.
-  cudaStream_t stream = Caffe::thread_stream();
+  cudaStream_t stream = Caffe::thread_stream(group);
   if (n & 1) {
     clean_last_element(const_cast<float16*>(x) + n, stream);
   }