616274018cc9c1ba23101315522eaab9861a5da9
1 // Copyright 2013 Yangqing Jia
3 #include <cublas_v2.h>
5 #include "caffe/blob.hpp"
6 #include "caffe/common.hpp"
7 #include "caffe/syncedmem.hpp"
9 namespace caffe {
11 template <typename Dtype>
12 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
13 const int width) {
14 CHECK_GE(num, 0);
15 CHECK_GE(channels, 0);
16 CHECK_GE(height, 0);
17 CHECK_GE(width, 0);
18 num_ = num;
19 channels_ = channels;
20 height_ = height;
21 width_ = width;
22 count_ = num_ * channels_ * height_ * width_;
23 if (count_) {
24 data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
25 diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
26 } else {
27 data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
28 diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
29 }
30 }
// Constructs a blob with the given dimensions; delegates all allocation
// and dimension bookkeeping to Reshape().
template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
    const int width) {
  Reshape(num, channels, height, width);
}
38 template <typename Dtype>
39 const Dtype* Blob<Dtype>::cpu_data() const {
40 CHECK(data_);
41 return (const Dtype*)data_->cpu_data();
42 }
44 template <typename Dtype>
45 const Dtype* Blob<Dtype>::gpu_data() const {
46 CHECK(data_);
47 return (const Dtype*)data_->gpu_data();
48 }
50 template <typename Dtype>
51 const Dtype* Blob<Dtype>::cpu_diff() const {
52 CHECK(diff_);
53 return (const Dtype*)diff_->cpu_data();
54 }
56 template <typename Dtype>
57 const Dtype* Blob<Dtype>::gpu_diff() const {
58 CHECK(diff_);
59 return (const Dtype*)diff_->gpu_data();
60 }
62 template <typename Dtype>
63 Dtype* Blob<Dtype>::mutable_cpu_data() {
64 CHECK(data_);
65 return reinterpret_cast<Dtype*>(data_->mutable_cpu_data());
66 }
68 template <typename Dtype>
69 Dtype* Blob<Dtype>::mutable_gpu_data() {
70 CHECK(data_);
71 return reinterpret_cast<Dtype*>(data_->mutable_gpu_data());
72 }
74 template <typename Dtype>
75 Dtype* Blob<Dtype>::mutable_cpu_diff() {
76 CHECK(diff_);
77 return reinterpret_cast<Dtype*>(diff_->mutable_cpu_data());
78 }
80 template <typename Dtype>
81 Dtype* Blob<Dtype>::mutable_gpu_diff() {
82 CHECK(diff_);
83 return reinterpret_cast<Dtype*>(diff_->mutable_gpu_data());
84 }
// Applies the accumulated diff to the data (i.e. data -= diff, in the
// eventual design). Deliberately aborts via LOG(FATAL) until implemented;
// callers must not rely on this yet.
template <typename Dtype>
void Blob<Dtype>::Update() {
  // not implemented yet.
  LOG(FATAL) << "not implemented";
  // We will perform update based on where the data is located.
}
93 template <typename Dtype>
94 void Blob<Dtype>::FromProto(const BlobProto& proto) {
95 Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
96 // copy data
97 Dtype* data_vec = mutable_cpu_data();
98 for (int i = 0; i < count_; ++i) {
99 data_vec[i] = proto.data(i);
100 }
101 if (proto.diff_size() > 0) {
102 Dtype* diff_vec = mutable_cpu_diff();
103 for (int i = 0; i < count_; ++i) {
104 diff_vec[i] = proto.diff(i);
105 }
106 }
107 }
109 template <typename Dtype>
110 void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
111 proto->set_num(num_);
112 proto->set_channels(channels_);
113 proto->set_height(height_);
114 proto->set_width(width_);
115 proto->clear_data();
116 proto->clear_diff();
117 const Dtype* data_vec = cpu_data();
118 for (int i = 0; i < count_; ++i) {
119 proto->add_data(data_vec[i]);
120 }
121 if (write_diff) {
122 const Dtype* diff_vec = cpu_diff();
123 for (int i = 0; i < count_; ++i) {
124 proto->add_diff(diff_vec[i]);
125 }
126 }
127 }
129 INSTANTIATE_CLASS(Blob);
131 } // namespace caffe