a bunch of updates.
[jacinto-ai/caffe-jacinto.git] / src / caffe / util / io.cpp
1 // Copyright 2013 Yangqing Jia
3 #include <stdint.h>
4 #include <fcntl.h>
5 #include <google/protobuf/text_format.h>
6 #include <google/protobuf/io/zero_copy_stream_impl.h>
7 #include <opencv2/core/core.hpp>
8 #include <opencv2/highgui/highgui.hpp>
10 #include <algorithm>
11 #include <string>
13 #include "caffe/common.hpp"
14 #include "caffe/util/io.hpp"
15 #include "caffe/proto/caffe.pb.h"
17 using cv::Mat;
18 using cv::Vec3b;
19 using std::max;
20 using std::string;
21 using google::protobuf::io::FileInputStream;
23 namespace caffe {
25 void ReadImageToProto(const string& filename, BlobProto* proto) {
26   Mat cv_img;
27   cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
28   CHECK(cv_img.data) << "Could not open or find the image.";
29   DCHECK_EQ(cv_img.channels(), 3);
30   proto->set_num(1);
31   proto->set_channels(3);
32   proto->set_height(cv_img.rows);
33   proto->set_width(cv_img.cols);
34   proto->clear_data();
35   proto->clear_diff();
36   for (int c = 0; c < 3; ++c) {
37     for (int h = 0; h < cv_img.rows; ++h) {
38       for (int w = 0; w < cv_img.cols; ++w) {
39         proto->add_data(static_cast<float>(cv_img.at<Vec3b>(h, w)[c]) / 255.);
40       }
41     }
42   }
43 }
45 void WriteProtoToImage(const string& filename, const BlobProto& proto) {
46   CHECK_EQ(proto.num(), 1);
47   CHECK(proto.channels() == 3 || proto.channels() == 1);
48   CHECK_GT(proto.height(), 0);
49   CHECK_GT(proto.width(), 0);
50   Mat cv_img(proto.height(), proto.width(), CV_8UC3);
51   if (proto.channels() == 1) {
52     for (int c = 0; c < 3; ++c) {
53       for (int h = 0; h < cv_img.rows; ++h) {
54         for (int w = 0; w < cv_img.cols; ++w) {
55           cv_img.at<Vec3b>(h, w)[c] =
56               uint8_t(proto.data(h * cv_img.cols + w) * 255.);
57         }
58       }
59     }
60   } else {
61     for (int c = 0; c < 3; ++c) {
62       for (int h = 0; h < cv_img.rows; ++h) {
63         for (int w = 0; w < cv_img.cols; ++w) {
64           cv_img.at<Vec3b>(h, w)[c] =
65               uint8_t(proto.data((c * cv_img.rows + h) * cv_img.cols + w)
66                   * 255.);
67         }
68       }
69     }
70   }
71   CHECK(cv::imwrite(filename, cv_img));
72 }
74 void ReadProtoFromTextFile(const char* filename,
75     ::google::protobuf::Message* proto) {
76   int fd = open(filename, O_RDONLY);
77   FileInputStream* input = new FileInputStream(fd);
78   CHECK(google::protobuf::TextFormat::Parse(input, proto));
79   delete input;
80   close(fd);
81 }
83 }  // namespace caffe