updated a bunch of things, ready to test if it breaks things
[jacinto-ai/caffe-jacinto.git] / src / caffe / blob.cpp
index 616274018cc9c1ba23101315522eaab9861a5da9..35e5b04adbddc82585733defb6eb5d25529ef9c5 100644 (file)
@@ -1,5 +1,6 @@
 // Copyright 2013 Yangqing Jia
 
+#include <cuda_runtime.h>
 #include <cublas_v2.h>
 
 #include "caffe/blob.hpp"
@@ -11,6 +12,7 @@ namespace caffe {
 template <typename Dtype>
 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
     const int width) {
+  int old_count = count_;
   CHECK_GE(num, 0);
   CHECK_GE(channels, 0);
   CHECK_GE(height, 0);
@@ -21,8 +23,10 @@ void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
   width_ = width;
   count_ = num_ * channels_ * height_ * width_;
   if (count_) {
-    data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
-    diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    if (old_count != count_) {
+      data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+      diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    }
   } else {
     data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
     diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
@@ -90,6 +94,40 @@ void Blob<Dtype>::Update() {
   // We will perform update based on where the data is located.
 }
 
+template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
+  if (num_ != source.num() || channels_ != source.channels() ||
+      height_ != source.height() || width_ != source.width()) {
+    if (reshape) {
+      Reshape(source.num(), source.channels(), source.height(), source.width());
+    } else {
+      LOG(FATAL) << "Trying to copy blobs of different sizes.";
+    }
+  }
+  switch (Caffe::mode()) {
+  case Caffe::GPU:
+    if (copy_diff) {
+      CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    } else {
+      CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    }
+    break;
+  case Caffe::CPU:
+    if (copy_diff) {
+      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+          sizeof(Dtype) * count_);
+    } else {
+      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+        sizeof(Dtype) * count_);
+    }
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode.";
+  }
+}
+
 template <typename Dtype>
 void Blob<Dtype>::FromProto(const BlobProto& proto) {
   Reshape(proto.num(), proto.channels(), proto.height(), proto.width());