Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
Sync shuffling
author: Sergei Nikolaev <snikolaev@nvidia.com>
Tue, 26 Dec 2017 05:17:24 +0000 (21:17 -0800)
committer: Sergei Nikolaev <snikolaev@nvidia.com>
Tue, 26 Dec 2017 05:17:24 +0000 (21:17 -0800)
include/caffe/data_reader.hpp
src/caffe/data_reader.cpp
src/caffe/layers/data_layer.cpp

index 85f20153093d00984c69854553cd5755c56f90f7..b8389c3bb73e4a0c2470186011624a91216e14c7 100644 (file)
@@ -159,7 +159,7 @@ class DataReader : public InternalThread {
     data_cache_->just_cached();
   }
 
-  std::mutex& shuffle_mutex() {
+  static shared_mutex& shuffle_mutex() {
     return shuffle_mutex_;
   }
 
@@ -174,7 +174,7 @@ protected:
   size_t batch_size_;
   const bool skip_one_batch_;
   DataParameter_DB backend_;
-  std::mutex shuffle_mutex_;
+  static shared_mutex shuffle_mutex_;
 
   shared_ptr<BlockingQueue<shared_ptr<Datum>>> init_;
   vector<shared_ptr<BlockingQueue<shared_ptr<Datum>>>> free_;
index b624883c7763641f93387c89279f1b744f355ee4..797718ac751a4554a316d1965017bcae461306bb 100644 (file)
@@ -10,6 +10,7 @@ namespace caffe {
 
 std::mutex DataReader::DataCache::cache_mutex_;
 unique_ptr<DataReader::DataCache> DataReader::DataCache::data_cache_inst_;
+shared_mutex DataReader::shuffle_mutex_;
 
 DataReader::DataReader(const LayerParameter& param,
     size_t solver_count,
@@ -158,8 +159,7 @@ shared_ptr<Datum>& DataReader::DataCache::next_cached(DataReader& reader) {
   std::lock_guard<std::mutex> lock(cache_mutex_);
   if (shuffle_ && cache_idx_== 0UL) {
     LOG(INFO) << "Shuffling " << cache_buffer_.size() << " records...";
-    // Every epoch we might shuffle
-    std::lock_guard<std::mutex> lock(reader.shuffle_mutex());
+    unique_lock<shared_mutex> shfllock(DataReader::shuffle_mutex());
     caffe::shuffle(cache_buffer_.begin(), cache_buffer_.end());
   }
   shared_ptr<Datum>& datum = cache_buffer_[cache_idx_++];
index fe39cc54418eb11524c9f1d02eea3919b2d6838a..36c86b1557bf4b51504a49707cc68367cda977b7 100644 (file)
@@ -284,8 +284,8 @@ void DataLayer<Ftype, Btype>::load_batch(Batch* batch, int thread_id, size_t que
   Btype* dst_cptr = nullptr;
   if (use_gpu_transform) {
 #ifndef CPU_ONLY
-    size_t holder_size = top_shape[0] * top_shape[1] * init_datum_height * init_datum_width;
-    tmp_gpu_buffer_[thread_id]->safe_reserve(holder_size);
+    size_t buffer_size = top_shape[0] * top_shape[1] * init_datum_height * init_datum_width;
+    tmp_gpu_buffer_[thread_id]->safe_reserve(buffer_size);
     dst_gptr = tmp_gpu_buffer_[thread_id]->data();
 #endif
   } else {
@@ -307,28 +307,23 @@ void DataLayer<Ftype, Btype>::load_batch(Batch* batch, int thread_id, size_t que
 
     if (use_gpu_transform) {
 #ifndef CPU_ONLY
+      // Every epoch we might shuffle
+      shared_lock<shared_mutex> lock(DataReader::shuffle_mutex());
       if (datum->encoded()) {
         DecodeDatumToSignedBuf(*datum, color_mode,
             &src_buf[src_buf_pos * datum_size], datum_size, false);
       } else {
         CHECK_EQ(datum_len, datum->channels() * datum->height() * datum->width())
           << "Datum size can't vary in the same batch";
-        // Every epoch we might shuffle
-        std::unique_ptr<std::lock_guard<std::mutex>> lock;
-        if (reader_) {
-          lock.reset(new std::lock_guard<std::mutex>(reader_->shuffle_mutex()));
-        } else if (sample_reader_) {
-          lock.reset(new std::lock_guard<std::mutex>(sample_reader_->shuffle_mutex()));
-        }
         src_ptr = datum->data().size() > 0 ?
                   &datum->data().front() :
                   reinterpret_cast<const char*>(&datum->float_data().Get(0));
-        std::memcpy(src_buf.data() +  // NOLINT(caffe/alt_fn)
-          src_buf_pos * datum_size, src_ptr, datum_size);  // NOLINT(caffe/alt_fn)
+        // NOLINT_NEXT_LINE(caffe/alt_fn)
+        std::memcpy(&src_buf[src_buf_pos * datum_size], src_ptr, datum_size);
       }
       ++src_buf_pos;
       if (src_buf_pos == src_buf_items) {
-        src_buf_pos = 0;
+        src_buf_pos = 0UL;
         CUDA_CHECK(cudaMemcpyAsync(
             reinterpret_cast<char*>(dst_gptr) + last_item_id * datum_size,
             src_buf.data(), src_buf_size, cudaMemcpyHostToDevice, stream));