1 // Copyright 2013 Yangqing Jia
#include <climits>
#include <cstring>

#include <cuda_runtime.h>
#include <cublas_v2.h>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
10 namespace caffe {
12 template <typename Dtype>
13 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
14 const int width) {
15 int old_count = count_;
16 CHECK_GE(num, 0);
17 CHECK_GE(channels, 0);
18 CHECK_GE(height, 0);
19 CHECK_GE(width, 0);
20 num_ = num;
21 channels_ = channels;
22 height_ = height;
23 width_ = width;
24 count_ = num_ * channels_ * height_ * width_;
25 if (count_) {
26 if (old_count != count_) {
27 data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
28 diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
29 }
30 } else {
31 data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
32 diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
33 }
34 }
36 template <typename Dtype>
37 Blob<Dtype>::Blob(const int num, const int channels, const int height,
38 const int width) {
39 Reshape(num, channels, height, width);
40 }
42 template <typename Dtype>
43 const Dtype* Blob<Dtype>::cpu_data() const {
44 CHECK(data_);
45 return (const Dtype*)data_->cpu_data();
46 }
48 template <typename Dtype>
49 const Dtype* Blob<Dtype>::gpu_data() const {
50 CHECK(data_);
51 return (const Dtype*)data_->gpu_data();
52 }
54 template <typename Dtype>
55 const Dtype* Blob<Dtype>::cpu_diff() const {
56 CHECK(diff_);
57 return (const Dtype*)diff_->cpu_data();
58 }
60 template <typename Dtype>
61 const Dtype* Blob<Dtype>::gpu_diff() const {
62 CHECK(diff_);
63 return (const Dtype*)diff_->gpu_data();
64 }
66 template <typename Dtype>
67 Dtype* Blob<Dtype>::mutable_cpu_data() {
68 CHECK(data_);
69 return reinterpret_cast<Dtype*>(data_->mutable_cpu_data());
70 }
72 template <typename Dtype>
73 Dtype* Blob<Dtype>::mutable_gpu_data() {
74 CHECK(data_);
75 return reinterpret_cast<Dtype*>(data_->mutable_gpu_data());
76 }
78 template <typename Dtype>
79 Dtype* Blob<Dtype>::mutable_cpu_diff() {
80 CHECK(diff_);
81 return reinterpret_cast<Dtype*>(diff_->mutable_cpu_data());
82 }
84 template <typename Dtype>
85 Dtype* Blob<Dtype>::mutable_gpu_diff() {
86 CHECK(diff_);
87 return reinterpret_cast<Dtype*>(diff_->mutable_gpu_data());
88 }
90 template <typename Dtype>
91 void Blob<Dtype>::Update() {
92 // not implemented yet.
93 LOG(FATAL) << "not implemented";
94 // We will perform update based on where the data is located.
95 }
97 template <typename Dtype>
98 void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
99 if (num_ != source.num() || channels_ != source.channels() ||
100 height_ != source.height() || width_ != source.width()) {
101 if (reshape) {
102 Reshape(source.num(), source.channels(), source.height(), source.width());
103 } else {
104 LOG(FATAL) << "Trying to copy blobs of different sizes.";
105 }
106 }
107 switch (Caffe::mode()) {
108 case Caffe::GPU:
109 if (copy_diff) {
110 CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
111 sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
112 } else {
113 CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
114 sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
115 }
116 break;
117 case Caffe::CPU:
118 if (copy_diff) {
119 memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
120 sizeof(Dtype) * count_);
121 } else {
122 memcpy(data_->mutable_cpu_data(), source.cpu_data(),
123 sizeof(Dtype) * count_);
124 }
125 break;
126 default:
127 LOG(FATAL) << "Unknown caffe mode.";
128 }
129 }
131 template <typename Dtype>
132 void Blob<Dtype>::FromProto(const BlobProto& proto) {
133 Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
134 // copy data
135 Dtype* data_vec = mutable_cpu_data();
136 for (int i = 0; i < count_; ++i) {
137 data_vec[i] = proto.data(i);
138 }
139 if (proto.diff_size() > 0) {
140 Dtype* diff_vec = mutable_cpu_diff();
141 for (int i = 0; i < count_; ++i) {
142 diff_vec[i] = proto.diff(i);
143 }
144 }
145 }
147 template <typename Dtype>
148 void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
149 proto->set_num(num_);
150 proto->set_channels(channels_);
151 proto->set_height(height_);
152 proto->set_width(width_);
153 proto->clear_data();
154 proto->clear_diff();
155 const Dtype* data_vec = cpu_data();
156 for (int i = 0; i < count_; ++i) {
157 proto->add_data(data_vec[i]);
158 }
159 if (write_diff) {
160 const Dtype* diff_vec = cpu_diff();
161 for (int i = 0; i < count_; ++i) {
162 proto->add_diff(diff_vec[i]);
163 }
164 }
165 }
// Explicitly instantiates Blob<Dtype> in this translation unit, for the element
// types chosen by the INSTANTIATE_CLASS macro (presumably float and double —
// confirm in caffe/common.hpp). The member definitions above live in this .cpp,
// so code in other files can only link against those instantiations.
INSTANTIATE_CLASS(Blob);
169 } // namespace caffe