1use std::borrow::Cow;
4use std::collections::VecDeque;
5use std::ffi::CString;
6use std::fmt;
7use std::io::{self, IoSlice, IoSliceMut};
8use std::num::NonZeroU32;
9use std::os::fd::{AsRawFd, OwnedFd, RawFd};
10
11use crate::ring_buffer::RingBuffer;
12use crate::{
13 ArgType, ArgValue, Fixed, IoMode, Message, MessageBuffersPool, MessageHeader, ObjectId,
14};
15
16mod unix;
17
/// Size, in bytes, of the ring buffer holding serialized outgoing messages.
pub const BYTES_OUT_LEN: usize = 4 * 1024;
/// Size, in bytes, of the ring buffer holding received, not yet parsed data.
pub const BYTES_IN_LEN: usize = 2 * BYTES_OUT_LEN;
/// Maximum number of file descriptors queued for sending (and carried by one message).
pub const FDS_OUT_LEN: usize = 28;
/// Maximum number of file descriptors a received message's signature may require.
pub const FDS_IN_LEN: usize = 2 * FDS_OUT_LEN;
22
/// A [`Transport`] wrapped with byte and file-descriptor buffers in both directions.
///
/// Outgoing messages are serialized into `bytes_out`/`fds_out` and only handed to the
/// transport when flushed; incoming data is received into `bytes_in`/`fds_in` and parsed
/// from there.
pub struct BufferedSocket<T> {
    /// The wrapped transport that performs the actual I/O.
    socket: T,
    /// Received bytes that have not been parsed into messages yet.
    bytes_in: RingBuffer,
    /// Serialized message bytes waiting to be sent.
    bytes_out: RingBuffer,
    /// Received file descriptors not yet handed out as message arguments.
    fds_in: VecDeque<OwnedFd>,
    /// File descriptors to pass along with the next send.
    fds_out: VecDeque<OwnedFd>,
}
36
/// A byte-stream connection that can also carry file descriptors; [`BufferedSocket`]
/// delegates all actual I/O to it.
pub trait Transport {
    /// The raw file descriptor that [`BufferedSocket`] reports through [`AsRawFd`].
    fn pollable_fd(&self) -> RawFd;

    /// Sends (a prefix of) `bytes` together with `fds` and returns how many bytes were sent.
    ///
    /// `mode` is forwarded unchanged from the caller (presumably selects blocking vs.
    /// non-blocking I/O — TODO confirm). After any `Ok` return, `BufferedSocket::flush`
    /// drops every fd in `fds`, so an implementation must have transmitted all of them by
    /// then.
    fn send(&mut self, bytes: &[IoSlice], fds: &[OwnedFd], mode: IoMode) -> io::Result<usize>;

    /// Receives into `bytes` and returns how many bytes were read.
    ///
    /// Received file descriptors go into `fds`; [`BufferedSocket`] takes them from the
    /// front, so they should be pushed to the back in arrival order. `mode` is forwarded
    /// as for [`Transport::send`].
    fn recv(
        &mut self,
        bytes: &mut [IoSliceMut],
        fds: &mut VecDeque<OwnedFd>,
        mode: IoMode,
    ) -> io::Result<usize>;
}
50
51impl<T: Transport> AsRawFd for BufferedSocket<T> {
52 fn as_raw_fd(&self) -> RawFd {
53 self.socket.pollable_fd()
54 }
55}
56
57impl<T: Transport> From<T> for BufferedSocket<T> {
58 fn from(socket: T) -> Self {
59 Self {
60 socket,
61 bytes_in: RingBuffer::new(BYTES_IN_LEN),
62 bytes_out: RingBuffer::new(BYTES_OUT_LEN),
63 fds_in: VecDeque::new(),
64 fds_out: VecDeque::new(),
65 }
66 }
67}
68
69pub struct SendMessageError {
71 pub msg: Message,
72 pub err: io::Error,
73}
74
/// Reasons why receiving and decoding a message can fail.
#[derive(Debug)]
pub enum RecvMessageError {
    /// Wraps an [`io::Error`] raised while receiving the message.
    Io(io::Error),
    /// The signature requires more file descriptors than the incoming buffer accepts.
    TooManyFds,
    /// The header declares a message larger than the incoming byte buffer.
    TooManyBytes,
    /// A non-nullable object id, new id or string argument was null.
    UnexpectedNull,
    /// A string argument had an interior nul byte or lacked its terminating nul.
    NullInString,
}

impl fmt::Display for RecvMessageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Io(error) => return write!(f, "io: {error}"),
            Self::TooManyFds => "message has too many file descriptors",
            Self::TooManyBytes => "message is too large",
            Self::UnexpectedNull => "message contains unexpected null",
            Self::NullInString => "message contains null byte in a string",
        };
        f.write_str(text)
    }
}

impl std::error::Error for RecvMessageError {}
98
/// Reasons why peeking at the next message header can fail.
#[derive(Debug)]
pub enum PeekHeaderError {
    /// Wraps an [`io::Error`] raised while receiving the header bytes.
    Io(io::Error),
    /// The header names object id 0.
    NullObject,
}

impl fmt::Display for PeekHeaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NullObject => f.write_str("header has a null object id"),
            Self::Io(inner) => f.write_fmt(format_args!("io: {inner}")),
        }
    }
}

impl std::error::Error for PeekHeaderError {}
116
117impl<T: Transport> BufferedSocket<T> {
118 pub fn write_message(
127 &mut self,
128 msg: Message,
129 msg_pool: &mut MessageBuffersPool,
130 mode: IoMode,
131 ) -> Result<(), SendMessageError> {
132 let size = MessageHeader::SIZE + msg.args.iter().map(ArgValue::size).sum::<usize>();
134 let fds_cnt = msg
135 .args
136 .iter()
137 .filter(|arg| matches!(arg, ArgValue::Fd(_)))
138 .count();
139
140 assert!(size <= BYTES_OUT_LEN);
142 assert!(fds_cnt <= FDS_OUT_LEN);
143 if size > self.bytes_out.writable_len() || fds_cnt + self.fds_out.len() > FDS_OUT_LEN {
144 if let Err(err) = self.flush(mode) {
145 return Err(SendMessageError { msg, err });
146 }
147 }
148
149 self.bytes_out.write_uint(msg.header.object_id.0.get());
151 self.bytes_out
152 .write_uint(((size as u32) << 16) | msg.header.opcode as u32);
153
154 let mut msg = msg;
156 for arg in msg.args.drain(..) {
157 match arg {
158 ArgValue::Uint(x) => self.bytes_out.write_uint(x),
159 ArgValue::Int(x) | ArgValue::Fixed(Fixed(x)) => self.bytes_out.write_int(x),
160 ArgValue::Object(ObjectId(x))
161 | ArgValue::OptObject(Some(ObjectId(x)))
162 | ArgValue::NewId(ObjectId(x)) => self.bytes_out.write_uint(x.get()),
163 ArgValue::OptObject(None) | ArgValue::OptString(None) => {
164 self.bytes_out.write_uint(0)
165 }
166 ArgValue::AnyNewId(iface, version, id) => {
167 self.send_array(iface.to_bytes_with_nul());
168 self.bytes_out.write_uint(version);
169 self.bytes_out.write_uint(id.0.get());
170 }
171 ArgValue::String(string) | ArgValue::OptString(Some(string)) => {
172 self.send_array(string.to_bytes_with_nul())
173 }
174 ArgValue::Array(array) => self.send_array(&array),
175 ArgValue::Fd(fd) => self.fds_out.push_back(fd),
176 }
177 }
178 msg_pool.reuse_args(msg.args);
179 Ok(())
180 }
181
182 pub fn peek_message_header(&mut self, mode: IoMode) -> Result<MessageHeader, PeekHeaderError> {
186 while self.bytes_in.readable_len() < MessageHeader::SIZE {
187 self.fill_incoming_buf(mode).map_err(PeekHeaderError::Io)?;
188 }
189
190 let mut raw = [0; MessageHeader::SIZE];
191 self.bytes_in.peek_bytes(&mut raw);
192 let object_id = u32::from_ne_bytes(raw[0..4].try_into().unwrap());
193 let size_and_opcode = u32::from_ne_bytes(raw[4..8].try_into().unwrap());
194
195 Ok(MessageHeader {
196 object_id: ObjectId(NonZeroU32::new(object_id).ok_or(PeekHeaderError::NullObject)?),
197 size: ((size_and_opcode & 0xFFFF_0000) >> 16) as u16,
198 opcode: (size_and_opcode & 0x0000_FFFF) as u16,
199 })
200 }
201
202 pub fn recv_message(
207 &mut self,
208 header: MessageHeader,
209 signature: &[ArgType],
210 msg_pool: &mut MessageBuffersPool,
211 mode: IoMode,
212 ) -> Result<Message, RecvMessageError> {
213 let fds_cnt = signature
215 .iter()
216 .filter(|arg| matches!(arg, ArgType::Fd))
217 .count();
218 if header.size as usize > BYTES_IN_LEN {
219 return Err(RecvMessageError::TooManyBytes);
220 }
221 if fds_cnt > FDS_IN_LEN {
222 return Err(RecvMessageError::TooManyFds);
223 }
224 while header.size as usize > self.bytes_in.readable_len() || fds_cnt > self.fds_in.len() {
225 self.fill_incoming_buf(mode).map_err(RecvMessageError::Io)?;
226 }
227
228 self.bytes_in.move_tail(MessageHeader::SIZE);
230
231 let mut args = msg_pool.get_args();
232 for arg_type in signature {
233 args.push(match arg_type {
234 ArgType::Int => ArgValue::Int(self.bytes_in.read_int()),
235 ArgType::Uint => ArgValue::Uint(self.bytes_in.read_uint()),
236 ArgType::Fixed => ArgValue::Fixed(Fixed(self.bytes_in.read_int())),
237 ArgType::Object => ArgValue::Object(
238 self.bytes_in
239 .read_id()
240 .ok_or(RecvMessageError::UnexpectedNull)?,
241 ),
242 ArgType::OptObject => ArgValue::OptObject(self.bytes_in.read_id()),
243 ArgType::NewId(_interface) => ArgValue::NewId(
244 self.bytes_in
245 .read_id()
246 .ok_or(RecvMessageError::UnexpectedNull)?,
247 ),
248 ArgType::AnyNewId => ArgValue::AnyNewId(
249 Cow::Owned(self.recv_string()?),
250 self.bytes_in.read_uint(),
251 self.bytes_in
252 .read_id()
253 .ok_or(RecvMessageError::UnexpectedNull)?,
254 ),
255 ArgType::String => ArgValue::String(self.recv_string()?),
256 ArgType::OptString => ArgValue::OptString(match self.bytes_in.read_uint() {
257 0 => None,
258 len => Some(self.recv_string_with_len(len)?),
259 }),
260 ArgType::Array => ArgValue::Array(self.recv_array()),
261 ArgType::Fd => ArgValue::Fd(self.fds_in.pop_front().unwrap()),
262 });
263 }
264
265 Ok(Message { header, args })
266 }
267
268 pub fn flush(&mut self, mode: IoMode) -> io::Result<()> {
270 while !self.bytes_out.is_empty() {
271 let mut iov_buf = [IoSlice::new(&[]), IoSlice::new(&[])];
272 let iov = self.bytes_out.get_readable_iov(&mut iov_buf);
273
274 let sent = self
275 .socket
276 .send(iov, self.fds_out.make_contiguous(), mode)?;
277
278 self.bytes_out.move_tail(sent);
279 self.fds_out.clear();
280 }
281
282 Ok(())
283 }
284
285 #[must_use]
287 pub fn transport(&self) -> &T {
288 &self.socket
289 }
290
291 #[must_use]
293 pub fn transport_mut(&mut self) -> &mut T {
294 &mut self.socket
295 }
296
297 fn fill_incoming_buf(&mut self, mode: IoMode) -> io::Result<()> {
298 if self.bytes_in.is_full() {
299 return Ok(());
300 }
301
302 let mut iov_buf = [IoSliceMut::new(&mut []), IoSliceMut::new(&mut [])];
303 let iov = self.bytes_in.get_writeable_iov(&mut iov_buf);
304
305 let read = self.socket.recv(iov, &mut self.fds_in, mode)?;
306 self.bytes_in.move_head(read);
307
308 Ok(())
309 }
310
311 fn send_array(&mut self, array: &[u8]) {
312 let len = array.len() as u32;
313
314 self.bytes_out.write_uint(len);
315 self.bytes_out.write_bytes(array);
316
317 let padding = ((4 - (len % 4)) % 4) as usize;
318 self.bytes_out.write_bytes(&[0, 0, 0][..padding]);
319 }
320
321 fn recv_array(&mut self) -> Vec<u8> {
322 let len = self.bytes_in.read_uint() as usize;
323
324 let mut buf = vec![0; len];
325 self.bytes_in.read_bytes(&mut buf);
326
327 let padding = (4 - (len % 4)) % 4;
328 self.bytes_in.move_tail(padding);
329
330 buf
331 }
332
333 fn recv_string_with_len(&mut self, len: u32) -> Result<CString, RecvMessageError> {
334 let mut buf = vec![0; len as usize];
335 self.bytes_in.read_bytes(&mut buf);
336
337 let padding = (4 - (len % 4)) % 4;
338 self.bytes_in.move_tail(padding as usize);
339
340 CString::from_vec_with_nul(buf).map_err(|_| RecvMessageError::NullInString)
341 }
342
343 fn recv_string(&mut self) -> Result<CString, RecvMessageError> {
344 let len = self.bytes_in.read_uint();
345 if len == 0 {
346 Err(RecvMessageError::UnexpectedNull)
347 } else {
348 self.recv_string_with_len(len)
349 }
350 }
351}