summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'libappfuse')
-rw-r--r--libappfuse/FuseBuffer.cc65
-rw-r--r--libappfuse/include/libappfuse/FuseBuffer.h4
-rw-r--r--libappfuse/tests/FuseBufferTest.cc26
3 files changed, 64 insertions, 31 deletions
diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc
index 882d54552..8fb2dbcc5 100644
--- a/libappfuse/FuseBuffer.cc
+++ b/libappfuse/FuseBuffer.cc
@@ -23,6 +23,7 @@
23#include <algorithm> 23#include <algorithm>
24#include <type_traits> 24#include <type_traits>
25 25
26#include <android-base/file.h>
26#include <android-base/logging.h> 27#include <android-base/logging.h>
27#include <android-base/macros.h> 28#include <android-base/macros.h>
28 29
@@ -34,57 +35,65 @@ static_assert(
34 "FuseBuffer must be standard layout union."); 35 "FuseBuffer must be standard layout union.");
35 36
36template <typename T> 37template <typename T>
37bool FuseMessage<T>::CheckPacketSize(size_t size, const char* name) const { 38bool FuseMessage<T>::CheckHeaderLength(const char* name) const {
38 const auto& header = static_cast<const T*>(this)->header; 39 const auto& header = static_cast<const T*>(this)->header;
39 if (size >= sizeof(header) && size <= sizeof(T)) { 40 if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
40 return true; 41 return true;
41 } else { 42 } else {
42 LOG(ERROR) << name << " is invalid=" << size; 43 LOG(ERROR) << "Invalid header length is found in " << name << ": " <<
44 header.len;
43 return false; 45 return false;
44 } 46 }
45} 47}
46 48
47template <typename T> 49template <typename T>
48bool FuseMessage<T>::CheckResult(int result, const char* operation_name) const { 50bool FuseMessage<T>::Read(int fd) {
49 if (result == 0) { 51 char* const buf = reinterpret_cast<char*>(this);
50 // Expected close of other endpoints. 52 const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T)));
51 return false;
52 }
53 if (result < 0) { 53 if (result < 0) {
54 PLOG(ERROR) << "Failed to " << operation_name << " a packet"; 54 PLOG(ERROR) << "Failed to read a FUSE message";
55 return false; 55 return false;
56 } 56 }
57 return true;
58}
59 57
60template <typename T>
61bool FuseMessage<T>::CheckHeaderLength(int result, const char* operation_name) const {
62 const auto& header = static_cast<const T*>(this)->header; 58 const auto& header = static_cast<const T*>(this)->header;
63 if (static_cast<uint32_t>(result) == header.len) { 59 if (result < static_cast<ssize_t>(sizeof(header))) {
64 return true; 60 LOG(ERROR) << "Read bytes " << result << " are shorter than header size " <<
65 } else { 61 sizeof(header);
66 LOG(ERROR) << "Invalid header length: operation_name=" << operation_name
67 << " result=" << result
68 << " header.len=" << header.len;
69 return false; 62 return false;
70 } 63 }
71}
72 64
73template <typename T> 65 if (!CheckHeaderLength("Read")) {
74bool FuseMessage<T>::Read(int fd) { 66 return false;
75 const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, this, sizeof(T))); 67 }
76 return CheckResult(result, "read") && CheckPacketSize(result, "read count") && 68
77 CheckHeaderLength(result, "read"); 69 if (static_cast<uint32_t>(result) > header.len) {
70 LOG(ERROR) << "Read bytes " << result << " are longer than header.len " <<
71 header.len;
72 return false;
73 }
74
75 if (!base::ReadFully(fd, buf + result, header.len - result)) {
76 PLOG(ERROR) << "ReadFully failed";
77 return false;
78 }
79
80 return true;
78} 81}
79 82
80template <typename T> 83template <typename T>
81bool FuseMessage<T>::Write(int fd) const { 84bool FuseMessage<T>::Write(int fd) const {
85 if (!CheckHeaderLength("Write")) {
86 return false;
87 }
88
89 const char* const buf = reinterpret_cast<const char*>(this);
82 const auto& header = static_cast<const T*>(this)->header; 90 const auto& header = static_cast<const T*>(this)->header;
83 if (!CheckPacketSize(header.len, "header.len")) { 91 if (!base::WriteFully(fd, buf, header.len)) {
92 PLOG(ERROR) << "WriteFully failed";
84 return false; 93 return false;
85 } 94 }
86 const ssize_t result = TEMP_FAILURE_RETRY(::write(fd, this, header.len)); 95
87 return CheckResult(result, "write") && CheckHeaderLength(result, "write"); 96 return true;
88} 97}
89 98
90template class FuseMessage<FuseRequest>; 99template class FuseMessage<FuseRequest>;
diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h
index 276db9020..7abd2fa40 100644
--- a/libappfuse/include/libappfuse/FuseBuffer.h
+++ b/libappfuse/include/libappfuse/FuseBuffer.h
@@ -34,9 +34,7 @@ class FuseMessage {
34 bool Read(int fd); 34 bool Read(int fd);
35 bool Write(int fd) const; 35 bool Write(int fd) const;
36 private: 36 private:
37 bool CheckPacketSize(size_t size, const char* name) const; 37 bool CheckHeaderLength(const char* name) const;
38 bool CheckResult(int result, const char* operation_name) const;
39 bool CheckHeaderLength(int result, const char* operation_name) const;
40}; 38};
41 39
42// FuseRequest represents file operation requests from /dev/fuse. It starts 40// FuseRequest represents file operation requests from /dev/fuse. It starts
diff --git a/libappfuse/tests/FuseBufferTest.cc b/libappfuse/tests/FuseBufferTest.cc
index c82213513..db35d330d 100644
--- a/libappfuse/tests/FuseBufferTest.cc
+++ b/libappfuse/tests/FuseBufferTest.cc
@@ -20,6 +20,8 @@
20#include <string.h> 20#include <string.h>
21#include <sys/socket.h> 21#include <sys/socket.h>
22 22
23#include <thread>
24
23#include <android-base/unique_fd.h> 25#include <android-base/unique_fd.h>
24#include <gtest/gtest.h> 26#include <gtest/gtest.h>
25 27
@@ -110,6 +112,30 @@ TEST(FuseMessageTest, Write_TooShort) {
110 TestWriteInvalidLength(sizeof(fuse_in_header) - 1); 112 TestWriteInvalidLength(sizeof(fuse_in_header) - 1);
111} 113}
112 114
115TEST(FuseMessageTest, ShortWriteAndRead) {
116 int raw_fds[2];
117 ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds));
118
119 android::base::unique_fd fds[2];
120 fds[0].reset(raw_fds[0]);
121 fds[1].reset(raw_fds[1]);
122
123 const int send_buffer_size = 1024;
124 ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size,
125 sizeof(int)));
126
127 bool succeed = false;
128 const int sender_fd = fds[0].get();
129 std::thread thread([sender_fd, &succeed] {
130 FuseRequest request;
131 request.header.len = 1024 * 4;
132 succeed = request.Write(sender_fd);
133 });
134 thread.detach();
135 FuseRequest request;
136 ASSERT_TRUE(request.Read(fds[1]));
137}
138
113TEST(FuseResponseTest, Reset) { 139TEST(FuseResponseTest, Reset) {
114 FuseResponse response; 140 FuseResponse response;
115 // Write 1 to the first ten bytes. 141 // Write 1 to the first ten bytes.