diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 8260fe0f99c9e8791725871cc4e5d359c00deb2e..35e5b04adbddc82585733defb6eb5d25529ef9c5 100644 (file)
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
// Copyright 2013 Yangqing Jia
+#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "caffe/blob.hpp"
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
+ int old_count = count_;
CHECK_GE(num, 0);
CHECK_GE(channels, 0);
CHECK_GE(height, 0);
width_ = width;
count_ = num_ * channels_ * height_ * width_;
if (count_) {
- data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ if (old_count != count_) {
+ data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ }
} else {
- data_.reset((SyncedMemory*)NULL);
- diff_.reset((SyncedMemory*)NULL);
+ data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
+ diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
}
}
// We will perform update based on where the data is located.
}
+template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
+ if (num_ != source.num() || channels_ != source.channels() ||
+ height_ != source.height() || width_ != source.width()) {
+ if (reshape) {
+ Reshape(source.num(), source.channels(), source.height(), source.width());
+ } else {
+ LOG(FATAL) << "Trying to copy blobs of different sizes.";
+ }
+ }
+ switch (Caffe::mode()) {
+ case Caffe::GPU:
+ if (copy_diff) {
+ CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
+ sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+ } else {
+ CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
+ sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+ }
+ break;
+ case Caffe::CPU:
+ if (copy_diff) {
+ memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+ sizeof(Dtype) * count_);
+ } else {
+ memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+ sizeof(Dtype) * count_);
+ }
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode.";
+ }
+}
+
template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
}
template <typename Dtype>
-void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) {
+void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
proto->set_num(num_);
proto->set_channels(channels_);
proto->set_height(height_);