]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/blob.cpp
proto update
[jacinto-ai/caffe-jacinto.git] / src / caffe / blob.cpp
1 // Copyright 2013 Yangqing Jia
3 #include <cublas_v2.h>
5 #include "caffe/blob.hpp"
6 #include "caffe/common.hpp"
7 #include "caffe/syncedmem.hpp"
9 namespace caffe {
11 template <typename Dtype>
12 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
13     const int width) {
14   CHECK_GT(num, 0);
15   CHECK_GT(channels, 0);
16   CHECK_GT(height, 0);
17   CHECK_GT(width, 0);
18   num_ = num;
19   channels_ = channels;
20   height_ = height;
21   width_ = width;
22   count_ = num_ * channels_ * height_ * width_;
23   data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
24   diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
25 }
27 template <typename Dtype>
28 Blob<Dtype>::Blob(const int num, const int channels, const int height,
29     const int width) {
30   Reshape(num, channels, height, width);
31 }
33 template <typename Dtype>
34 Blob<Dtype>::Blob(const Blob<Dtype>& source) {
35   if (source.count() == 0) {
36     Blob();
37   } else {
38     Reshape(source.num(), source.channels(), source.height(),
39         source.width());
40     // create the synced memories.
41     data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
42     diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
43     // Copy the data.
44     memcpy(data_->mutable_cpu_data(), source.cpu_data(),
45         count_ * sizeof(Dtype));
46     memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
47         count_ * sizeof(Dtype));
48   }
49 }
51 template <typename Dtype>
52 const Dtype* Blob<Dtype>::cpu_data() const {
53   CHECK(data_);
54   return (const Dtype*)data_->cpu_data();
55 }
57 template <typename Dtype>
58 const Dtype* Blob<Dtype>::gpu_data() const {
59   CHECK(data_);
60   return (const Dtype*)data_->gpu_data();
61 }
63 template <typename Dtype>
64 const Dtype* Blob<Dtype>::cpu_diff() const {
65   CHECK(diff_);
66   return (const Dtype*)diff_->cpu_data();
67 }
69 template <typename Dtype>
70 const Dtype* Blob<Dtype>::gpu_diff() const {
71   CHECK(diff_);
72   return (const Dtype*)diff_->gpu_data();
73 }
75 template <typename Dtype>
76 Dtype* Blob<Dtype>::mutable_cpu_data() {
77   CHECK(data_);
78   return reinterpret_cast<Dtype*>(data_->mutable_cpu_data());
79 }
81 template <typename Dtype>
82 Dtype* Blob<Dtype>::mutable_gpu_data() {
83   CHECK(data_);
84   return reinterpret_cast<Dtype*>(data_->mutable_gpu_data());
85 }
87 template <typename Dtype>
88 Dtype* Blob<Dtype>::mutable_cpu_diff() {
89   CHECK(diff_);
90   return reinterpret_cast<Dtype*>(diff_->mutable_cpu_data());
91 }
93 template <typename Dtype>
94 Dtype* Blob<Dtype>::mutable_gpu_diff() {
95   CHECK(diff_);
96   return reinterpret_cast<Dtype*>(diff_->mutable_gpu_data());
97 }
99 template <typename Dtype>
100 void Blob<Dtype>::Update() {
101   // not implemented yet.
102   LOG(FATAL) << "not implemented";
105 template <typename Dtype>
106 void Blob<Dtype>::FromProto(const BlobProto& proto) {
107   Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
108   // copy data
109   Dtype* data_vec = mutable_cpu_data();
110   for (int i = 0; i < count_; ++i) {
111     data_vec[i] = proto.data(i);
112   }
113   Dtype* diff_vec = mutable_cpu_diff();
114   for (int i = 0; i < count_; ++i) {
115     diff_vec[i] = proto.diff(i);
116   }
119 template <typename Dtype>
120 void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) {
121   proto->set_num(num_);
122   proto->set_channels(channels_);
123   proto->set_height(height_);
124   proto->set_width(width_);
125   proto->clear_data();
126   proto->clear_diff();
127   const Dtype* data_vec = cpu_data();
128   for (int i = 0; i < count_; ++i) {
129     proto->add_data(data_vec[i]);
130   }
131   if (write_diff) {
132     const Dtype* diff_vec = cpu_diff();
133     for (int i = 0; i < count_; ++i) {
134       proto->add_diff(diff_vec[i]);
135     }
136   }
139 INSTANTIATE_CLASS(Blob);
141 }  // namespace caffe