summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to 'libappfuse')
-rw-r--r--libappfuse/FuseBuffer.cc157
-rw-r--r--libappfuse/include/libappfuse/FuseBuffer.h17
-rw-r--r--libappfuse/tests/FuseAppLoopTest.cc5
-rw-r--r--libappfuse/tests/FuseBridgeLoopTest.cc11
-rw-r--r--libappfuse/tests/FuseBufferTest.cc48
5 files changed, 148 insertions, 90 deletions
diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc
index 8fb2dbcc5..13cfc88ec 100644
--- a/libappfuse/FuseBuffer.cc
+++ b/libappfuse/FuseBuffer.cc
@@ -23,77 +23,132 @@
23#include <algorithm> 23#include <algorithm>
24#include <type_traits> 24#include <type_traits>
25 25
26#include <sys/socket.h>
27
26#include <android-base/file.h> 28#include <android-base/file.h>
27#include <android-base/logging.h> 29#include <android-base/logging.h>
28#include <android-base/macros.h> 30#include <android-base/macros.h>
29 31
30namespace android { 32namespace android {
31namespace fuse { 33namespace fuse {
32 34namespace {
33static_assert(
34 std::is_standard_layout<FuseBuffer>::value,
35 "FuseBuffer must be standard layout union.");
36 35
37template <typename T> 36template <typename T>
38bool FuseMessage<T>::CheckHeaderLength(const char* name) const { 37bool CheckHeaderLength(const FuseMessage<T>* self, const char* name) {
39 const auto& header = static_cast<const T*>(this)->header; 38 const auto& header = static_cast<const T*>(self)->header;
40 if (header.len >= sizeof(header) && header.len <= sizeof(T)) { 39 if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
41 return true; 40 return true;
42 } else { 41 } else {
43 LOG(ERROR) << "Invalid header length is found in " << name << ": " << 42 LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len;
44 header.len; 43 return false;
45 return false; 44 }
46 }
47} 45}
48 46
49template <typename T> 47template <typename T>
50bool FuseMessage<T>::Read(int fd) { 48ResultOrAgain ReadInternal(FuseMessage<T>* self, int fd, int sockflag) {
51 char* const buf = reinterpret_cast<char*>(this); 49 char* const buf = reinterpret_cast<char*>(self);
52 const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T))); 50 const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag))
53 if (result < 0) { 51 : TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T)));
54 PLOG(ERROR) << "Failed to read a FUSE message"; 52
55 return false; 53 switch (result) {
56 } 54 case 0:
57 55 // Expected EOF.
58 const auto& header = static_cast<const T*>(this)->header; 56 return ResultOrAgain::kFailure;
59 if (result < static_cast<ssize_t>(sizeof(header))) { 57 case -1:
60 LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << 58 if (errno == EAGAIN) {
61 sizeof(header); 59 return ResultOrAgain::kAgain;
62 return false; 60 }
63 } 61 PLOG(ERROR) << "Failed to read a FUSE message";
62 return ResultOrAgain::kFailure;
63 }
64
65 const auto& header = static_cast<const T*>(self)->header;
66 if (result < static_cast<ssize_t>(sizeof(header))) {
67 LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header);
68 return ResultOrAgain::kFailure;
69 }
70
71 if (!CheckHeaderLength<T>(self, "Read")) {
72 return ResultOrAgain::kFailure;
73 }
74
75 if (static_cast<uint32_t>(result) != header.len) {
76 LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len;
77 return ResultOrAgain::kFailure;
78 }
79
80 return ResultOrAgain::kSuccess;
81}
64 82
65 if (!CheckHeaderLength("Read")) { 83template <typename T>
66 return false; 84ResultOrAgain WriteInternal(const FuseMessage<T>* self, int fd, int sockflag) {
67 } 85 if (!CheckHeaderLength<T>(self, "Write")) {
86 return ResultOrAgain::kFailure;
87 }
88
89 const char* const buf = reinterpret_cast<const char*>(self);
90 const auto& header = static_cast<const T*>(self)->header;
91 const int result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag))
92 : TEMP_FAILURE_RETRY(write(fd, buf, header.len));
93
94 if (result == -1) {
95 if (errno == EAGAIN) {
96 return ResultOrAgain::kAgain;
97 }
98 PLOG(ERROR) << "Failed to write a FUSE message";
99 return ResultOrAgain::kFailure;
100 }
101
102 CHECK(static_cast<uint32_t>(result) == header.len);
103 return ResultOrAgain::kSuccess;
104}
105}
68 106
69 if (static_cast<uint32_t>(result) > header.len) { 107static_assert(std::is_standard_layout<FuseBuffer>::value,
70 LOG(ERROR) << "Read bytes " << result << " are longer than header.len " << 108 "FuseBuffer must be standard layout union.");
71 header.len; 109
72 return false; 110bool SetupMessageSockets(base::unique_fd (*result)[2]) {
73 } 111 base::unique_fd fds[2];
112 {
113 int raw_fds[2];
114 if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) {
115 PLOG(ERROR) << "Failed to create sockets for proxy";
116 return false;
117 }
118 fds[0].reset(raw_fds[0]);
119 fds[1].reset(raw_fds[1]);
120 }
121
122 constexpr int kMaxMessageSize = sizeof(FuseBuffer);
123 if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 ||
124 setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) {
125 PLOG(ERROR) << "Failed to update buffer size for socket";
126 return false;
127 }
128
129 (*result)[0] = std::move(fds[0]);
130 (*result)[1] = std::move(fds[1]);
131 return true;
132}
74 133
75 if (!base::ReadFully(fd, buf + result, header.len - result)) { 134template <typename T>
76 PLOG(ERROR) << "ReadFully failed"; 135bool FuseMessage<T>::Read(int fd) {
77 return false; 136 return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess;
78 } 137}
79 138
80 return true; 139template <typename T>
140ResultOrAgain FuseMessage<T>::ReadOrAgain(int fd) {
141 return ReadInternal(this, fd, MSG_DONTWAIT);
81} 142}
82 143
83template <typename T> 144template <typename T>
84bool FuseMessage<T>::Write(int fd) const { 145bool FuseMessage<T>::Write(int fd) const {
85 if (!CheckHeaderLength("Write")) { 146 return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess;
86 return false; 147}
87 }
88
89 const char* const buf = reinterpret_cast<const char*>(this);
90 const auto& header = static_cast<const T*>(this)->header;
91 if (!base::WriteFully(fd, buf, header.len)) {
92 PLOG(ERROR) << "WriteFully failed";
93 return false;
94 }
95 148
96 return true; 149template <typename T>
150ResultOrAgain FuseMessage<T>::WriteOrAgain(int fd) const {
151 return WriteInternal(this, fd, MSG_DONTWAIT);
97} 152}
98 153
99template class FuseMessage<FuseRequest>; 154template class FuseMessage<FuseRequest>;
diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h
index 7abd2fa40..fbb05d633 100644
--- a/libappfuse/include/libappfuse/FuseBuffer.h
+++ b/libappfuse/include/libappfuse/FuseBuffer.h
@@ -17,6 +17,7 @@
17#ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_ 17#ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
18#define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_ 18#define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
19 19
20#include <android-base/unique_fd.h>
20#include <linux/fuse.h> 21#include <linux/fuse.h>
21 22
22namespace android { 23namespace android {
@@ -28,12 +29,24 @@ constexpr size_t kFuseMaxWrite = 256 * 1024;
28constexpr size_t kFuseMaxRead = 128 * 1024; 29constexpr size_t kFuseMaxRead = 128 * 1024;
29constexpr int32_t kFuseSuccess = 0; 30constexpr int32_t kFuseSuccess = 0;
30 31
32// Setup sockets to transfer FuseMessage.
33bool SetupMessageSockets(base::unique_fd (*sockets)[2]);
34
35enum class ResultOrAgain {
36 kSuccess,
37 kFailure,
38 kAgain,
39};
40
31template<typename T> 41template<typename T>
32class FuseMessage { 42class FuseMessage {
33 public: 43 public:
34 bool Read(int fd); 44 bool Read(int fd);
35 bool Write(int fd) const; 45 bool Write(int fd) const;
36 private: 46 ResultOrAgain ReadOrAgain(int fd);
47 ResultOrAgain WriteOrAgain(int fd) const;
48
49private:
37 bool CheckHeaderLength(const char* name) const; 50 bool CheckHeaderLength(const char* name) const;
38}; 51};
39 52
@@ -54,7 +67,7 @@ struct FuseRequest : public FuseMessage<FuseRequest> {
54 // for FUSE_READ 67 // for FUSE_READ
55 fuse_read_in read_in; 68 fuse_read_in read_in;
56 // for FUSE_LOOKUP 69 // for FUSE_LOOKUP
57 char lookup_name[0]; 70 char lookup_name[kFuseMaxWrite];
58 }; 71 };
59 void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique); 72 void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique);
60}; 73};
diff --git a/libappfuse/tests/FuseAppLoopTest.cc b/libappfuse/tests/FuseAppLoopTest.cc
index 25906cf1c..64dd81330 100644
--- a/libappfuse/tests/FuseAppLoopTest.cc
+++ b/libappfuse/tests/FuseAppLoopTest.cc
@@ -109,10 +109,7 @@ class FuseAppLoopTest : public ::testing::Test {
109 109
110 void SetUp() override { 110 void SetUp() override {
111 base::SetMinimumLogSeverity(base::VERBOSE); 111 base::SetMinimumLogSeverity(base::VERBOSE);
112 int sockets[2]; 112 ASSERT_TRUE(SetupMessageSockets(&sockets_));
113 ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, sockets));
114 sockets_[0].reset(sockets[0]);
115 sockets_[1].reset(sockets[1]);
116 thread_ = std::thread([this] { 113 thread_ = std::thread([this] {
117 StartFuseAppLoop(sockets_[1].release(), &callback_); 114 StartFuseAppLoop(sockets_[1].release(), &callback_);
118 }); 115 });
diff --git a/libappfuse/tests/FuseBridgeLoopTest.cc b/libappfuse/tests/FuseBridgeLoopTest.cc
index e74d9e700..b4c1efb01 100644
--- a/libappfuse/tests/FuseBridgeLoopTest.cc
+++ b/libappfuse/tests/FuseBridgeLoopTest.cc
@@ -50,15 +50,8 @@ class FuseBridgeLoopTest : public ::testing::Test {
50 50
51 void SetUp() override { 51 void SetUp() override {
52 base::SetMinimumLogSeverity(base::VERBOSE); 52 base::SetMinimumLogSeverity(base::VERBOSE);
53 int dev_sockets[2]; 53 ASSERT_TRUE(SetupMessageSockets(&dev_sockets_));
54 int proxy_sockets[2]; 54 ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_));
55 ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, dev_sockets));
56 ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, proxy_sockets));
57 dev_sockets_[0].reset(dev_sockets[0]);
58 dev_sockets_[1].reset(dev_sockets[1]);
59 proxy_sockets_[0].reset(proxy_sockets[0]);
60 proxy_sockets_[1].reset(proxy_sockets[1]);
61
62 thread_ = std::thread([this] { 55 thread_ = std::thread([this] {
63 StartFuseBridgeLoop( 56 StartFuseBridgeLoop(
64 dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_); 57 dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_);
diff --git a/libappfuse/tests/FuseBufferTest.cc b/libappfuse/tests/FuseBufferTest.cc
index 1a1abd57e..ade34acc1 100644
--- a/libappfuse/tests/FuseBufferTest.cc
+++ b/libappfuse/tests/FuseBufferTest.cc
@@ -112,30 +112,6 @@ TEST(FuseMessageTest, Write_TooShort) {
112 TestWriteInvalidLength(sizeof(fuse_in_header) - 1); 112 TestWriteInvalidLength(sizeof(fuse_in_header) - 1);
113} 113}
114 114
115TEST(FuseMessageTest, ShortWriteAndRead) {
116 int raw_fds[2];
117 ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds));
118
119 android::base::unique_fd fds[2];
120 fds[0].reset(raw_fds[0]);
121 fds[1].reset(raw_fds[1]);
122
123 const int send_buffer_size = 1024;
124 ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size,
125 sizeof(int)));
126
127 bool succeed = false;
128 const int sender_fd = fds[0].get();
129 std::thread thread([sender_fd, &succeed] {
130 FuseRequest request;
131 request.header.len = 1024 * 4;
132 succeed = request.Write(sender_fd);
133 });
134 thread.detach();
135 FuseRequest request;
136 ASSERT_TRUE(request.Read(fds[1]));
137}
138
139TEST(FuseResponseTest, Reset) { 115TEST(FuseResponseTest, Reset) {
140 FuseResponse response; 116 FuseResponse response;
141 // Write 1 to the first ten bytes. 117 // Write 1 to the first ten bytes.
@@ -211,5 +187,29 @@ TEST(FuseBufferTest, HandleNotImpl) {
211 EXPECT_EQ(-ENOSYS, buffer.response.header.error); 187 EXPECT_EQ(-ENOSYS, buffer.response.header.error);
212} 188}
213 189
190TEST(SetupMessageSocketsTest, Stress) {
191 constexpr int kCount = 1000;
192
193 FuseRequest request;
194 request.header.len = sizeof(FuseRequest);
195
196 base::unique_fd fds[2];
197 SetupMessageSockets(&fds);
198
199 std::thread thread([&fds] {
200 FuseRequest request;
201 for (int i = 0; i < kCount; ++i) {
202 ASSERT_TRUE(request.Read(fds[1]));
203 usleep(1000);
204 }
205 });
206
207 for (int i = 0; i < kCount; ++i) {
208 ASSERT_TRUE(request.Write(fds[0]));
209 }
210
211 thread.join();
212}
213
214} // namespace fuse 214} // namespace fuse
215} // namespace android 215} // namespace android