wayrs_core/
transport.rs

//! Wayland transport methods: the [`Transport`] trait and the [`BufferedSocket`] that marshals messages over it.
2
3use std::borrow::Cow;
4use std::collections::VecDeque;
5use std::ffi::CString;
6use std::fmt;
7use std::io::{self, IoSlice, IoSliceMut};
8use std::num::NonZeroU32;
9use std::os::fd::{AsRawFd, OwnedFd, RawFd};
10
11use crate::ring_buffer::RingBuffer;
12use crate::{
13    ArgType, ArgValue, Fixed, IoMode, Message, MessageBuffersPool, MessageHeader, ObjectId,
14};
15
16mod unix;
17
/// Capacity of the outgoing byte buffer; also the largest message `write_message` accepts.
pub const BYTES_OUT_LEN: usize = 4096;
/// Capacity of the incoming byte buffer; also the largest message `recv_message` accepts.
pub const BYTES_IN_LEN: usize = 2 * BYTES_OUT_LEN;
/// Maximum number of file descriptors queued for sending before a flush is forced.
///
/// NOTE(review): presumably chosen to match libwayland's outgoing fd limit — confirm.
pub const FDS_OUT_LEN: usize = 28;
/// Maximum number of file descriptors a single received message may carry.
pub const FDS_IN_LEN: usize = 2 * FDS_OUT_LEN;
22
23/// A buffered Wayland socket
24///
25/// Handles message marshalling and unmarshalling. This struct is generic over [`Transport`], which
26/// is usually [`UnixStream`](std::os::unix::net::UnixStream).
27///
28/// To create a new instance, use the `From<T: Transport>` implementation.
29pub struct BufferedSocket<T> {
30    socket: T,
31    bytes_in: RingBuffer,
32    bytes_out: RingBuffer,
33    fds_in: VecDeque<OwnedFd>,
34    fds_out: VecDeque<OwnedFd>,
35}
36
37/// An abstraction over Wayland transport methods
38pub trait Transport {
39    fn pollable_fd(&self) -> RawFd;
40
41    fn send(&mut self, bytes: &[IoSlice], fds: &[OwnedFd], mode: IoMode) -> io::Result<usize>;
42
43    fn recv(
44        &mut self,
45        bytes: &mut [IoSliceMut],
46        fds: &mut VecDeque<OwnedFd>,
47        mode: IoMode,
48    ) -> io::Result<usize>;
49}
50
51impl<T: Transport> AsRawFd for BufferedSocket<T> {
52    fn as_raw_fd(&self) -> RawFd {
53        self.socket.pollable_fd()
54    }
55}
56
57impl<T: Transport> From<T> for BufferedSocket<T> {
58    fn from(socket: T) -> Self {
59        Self {
60            socket,
61            bytes_in: RingBuffer::new(BYTES_IN_LEN),
62            bytes_out: RingBuffer::new(BYTES_OUT_LEN),
63            fds_in: VecDeque::new(),
64            fds_out: VecDeque::new(),
65        }
66    }
67}
68
69/// An error occurred while sending a message
70pub struct SendMessageError {
71    pub msg: Message,
72    pub err: io::Error,
73}
74
/// An error that occurred while trying to receive a message.
#[derive(Debug)]
pub enum RecvMessageError {
    /// An I/O error occurred while reading the message.
    Io(io::Error),
    /// The message signature calls for more file descriptors than can be buffered.
    TooManyFds,
    /// The message header declares a size larger than the incoming buffer.
    TooManyBytes,
    /// A non-nullable argument (object id or string) was null.
    UnexpectedNull,
    /// A string argument was not a valid nul-terminated C string (for example, it contained an
    /// interior nul byte).
    NullInString,
}

impl std::error::Error for RecvMessageError {}

impl fmt::Display for RecvMessageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(error) => write!(f, "io: {error}"),
            Self::TooManyFds => f.write_str("message has too many file descriptors"),
            Self::TooManyBytes => f.write_str("message is too large"),
            Self::UnexpectedNull => f.write_str("message contains unexpected null"),
            Self::NullInString => f.write_str("message contains null byte in a string"),
        }
    }
}
98
/// An error that occurred while trying to peek at the next message header.
#[derive(Debug)]
pub enum PeekHeaderError {
    /// An I/O error occurred while reading the header.
    Io(io::Error),
    /// The header's object id was zero.
    NullObject,
}

impl std::error::Error for PeekHeaderError {}

impl fmt::Display for PeekHeaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(error) => write!(f, "io: {error}"),
            Self::NullObject => f.write_str("header has a null object id"),
        }
    }
}
116
117impl<T: Transport> BufferedSocket<T> {
118    /// Write a single Wayland message into the intevnal buffer.
119    ///
120    /// Flushes the buffer if neccessary. On failure, ownership of the message is returned.
121    ///
122    /// # Panics
123    ///
124    /// This function panics if the message size is larger than `BYTES_OUT_LEN` or it contains more
125    /// than `FDS_OUT_LEN` file descriptors.
126    pub fn write_message(
127        &mut self,
128        msg: Message,
129        msg_pool: &mut MessageBuffersPool,
130        mode: IoMode,
131    ) -> Result<(), SendMessageError> {
132        // Calc size
133        let size = MessageHeader::SIZE + msg.args.iter().map(ArgValue::size).sum::<usize>();
134        let fds_cnt = msg
135            .args
136            .iter()
137            .filter(|arg| matches!(arg, ArgValue::Fd(_)))
138            .count();
139
140        // Check size and flush if neccessary
141        assert!(size <= BYTES_OUT_LEN);
142        assert!(fds_cnt <= FDS_OUT_LEN);
143        if size > self.bytes_out.writable_len() || fds_cnt + self.fds_out.len() > FDS_OUT_LEN {
144            if let Err(err) = self.flush(mode) {
145                return Err(SendMessageError { msg, err });
146            }
147        }
148
149        // Header
150        self.bytes_out.write_uint(msg.header.object_id.0.get());
151        self.bytes_out
152            .write_uint(((size as u32) << 16) | msg.header.opcode as u32);
153
154        // Args
155        let mut msg = msg;
156        for arg in msg.args.drain(..) {
157            match arg {
158                ArgValue::Uint(x) => self.bytes_out.write_uint(x),
159                ArgValue::Int(x) | ArgValue::Fixed(Fixed(x)) => self.bytes_out.write_int(x),
160                ArgValue::Object(ObjectId(x))
161                | ArgValue::OptObject(Some(ObjectId(x)))
162                | ArgValue::NewId(ObjectId(x)) => self.bytes_out.write_uint(x.get()),
163                ArgValue::OptObject(None) | ArgValue::OptString(None) => {
164                    self.bytes_out.write_uint(0)
165                }
166                ArgValue::AnyNewId(iface, version, id) => {
167                    self.send_array(iface.to_bytes_with_nul());
168                    self.bytes_out.write_uint(version);
169                    self.bytes_out.write_uint(id.0.get());
170                }
171                ArgValue::String(string) | ArgValue::OptString(Some(string)) => {
172                    self.send_array(string.to_bytes_with_nul())
173                }
174                ArgValue::Array(array) => self.send_array(&array),
175                ArgValue::Fd(fd) => self.fds_out.push_back(fd),
176            }
177        }
178        msg_pool.reuse_args(msg.args);
179        Ok(())
180    }
181
182    /// Peek the next message header.
183    ///
184    /// Fills the internal buffer if needed and keeps the header in the buffer.
185    pub fn peek_message_header(&mut self, mode: IoMode) -> Result<MessageHeader, PeekHeaderError> {
186        while self.bytes_in.readable_len() < MessageHeader::SIZE {
187            self.fill_incoming_buf(mode).map_err(PeekHeaderError::Io)?;
188        }
189
190        let mut raw = [0; MessageHeader::SIZE];
191        self.bytes_in.peek_bytes(&mut raw);
192        let object_id = u32::from_ne_bytes(raw[0..4].try_into().unwrap());
193        let size_and_opcode = u32::from_ne_bytes(raw[4..8].try_into().unwrap());
194
195        Ok(MessageHeader {
196            object_id: ObjectId(NonZeroU32::new(object_id).ok_or(PeekHeaderError::NullObject)?),
197            size: ((size_and_opcode & 0xFFFF_0000) >> 16) as u16,
198            opcode: (size_and_opcode & 0x0000_FFFF) as u16,
199        })
200    }
201
202    /// Receive the entire next message.
203    ///
204    /// Fills the internal buffer if needed. `header` must be the value returned by
205    /// [`Self::peek_message_header`] right before calling this function.
206    pub fn recv_message(
207        &mut self,
208        header: MessageHeader,
209        signature: &[ArgType],
210        msg_pool: &mut MessageBuffersPool,
211        mode: IoMode,
212    ) -> Result<Message, RecvMessageError> {
213        // Check size and fill buffer if necessary
214        let fds_cnt = signature
215            .iter()
216            .filter(|arg| matches!(arg, ArgType::Fd))
217            .count();
218        if header.size as usize > BYTES_IN_LEN {
219            return Err(RecvMessageError::TooManyBytes);
220        }
221        if fds_cnt > FDS_IN_LEN {
222            return Err(RecvMessageError::TooManyFds);
223        }
224        while header.size as usize > self.bytes_in.readable_len() || fds_cnt > self.fds_in.len() {
225            self.fill_incoming_buf(mode).map_err(RecvMessageError::Io)?;
226        }
227
228        // Consume header
229        self.bytes_in.move_tail(MessageHeader::SIZE);
230
231        let mut args = msg_pool.get_args();
232        for arg_type in signature {
233            args.push(match arg_type {
234                ArgType::Int => ArgValue::Int(self.bytes_in.read_int()),
235                ArgType::Uint => ArgValue::Uint(self.bytes_in.read_uint()),
236                ArgType::Fixed => ArgValue::Fixed(Fixed(self.bytes_in.read_int())),
237                ArgType::Object => ArgValue::Object(
238                    self.bytes_in
239                        .read_id()
240                        .ok_or(RecvMessageError::UnexpectedNull)?,
241                ),
242                ArgType::OptObject => ArgValue::OptObject(self.bytes_in.read_id()),
243                ArgType::NewId(_interface) => ArgValue::NewId(
244                    self.bytes_in
245                        .read_id()
246                        .ok_or(RecvMessageError::UnexpectedNull)?,
247                ),
248                ArgType::AnyNewId => ArgValue::AnyNewId(
249                    Cow::Owned(self.recv_string()?),
250                    self.bytes_in.read_uint(),
251                    self.bytes_in
252                        .read_id()
253                        .ok_or(RecvMessageError::UnexpectedNull)?,
254                ),
255                ArgType::String => ArgValue::String(self.recv_string()?),
256                ArgType::OptString => ArgValue::OptString(match self.bytes_in.read_uint() {
257                    0 => None,
258                    len => Some(self.recv_string_with_len(len)?),
259                }),
260                ArgType::Array => ArgValue::Array(self.recv_array()),
261                ArgType::Fd => ArgValue::Fd(self.fds_in.pop_front().unwrap()),
262            });
263        }
264
265        Ok(Message { header, args })
266    }
267
268    /// Flush all pending messages.
269    pub fn flush(&mut self, mode: IoMode) -> io::Result<()> {
270        while !self.bytes_out.is_empty() {
271            let mut iov_buf = [IoSlice::new(&[]), IoSlice::new(&[])];
272            let iov = self.bytes_out.get_readable_iov(&mut iov_buf);
273
274            let sent = self
275                .socket
276                .send(iov, self.fds_out.make_contiguous(), mode)?;
277
278            self.bytes_out.move_tail(sent);
279            self.fds_out.clear();
280        }
281
282        Ok(())
283    }
284
285    /// Get a reference to the underlying transport.
286    #[must_use]
287    pub fn transport(&self) -> &T {
288        &self.socket
289    }
290
291    /// Get a mutable reference to the underlying transport.
292    #[must_use]
293    pub fn transport_mut(&mut self) -> &mut T {
294        &mut self.socket
295    }
296
297    fn fill_incoming_buf(&mut self, mode: IoMode) -> io::Result<()> {
298        if self.bytes_in.is_full() {
299            return Ok(());
300        }
301
302        let mut iov_buf = [IoSliceMut::new(&mut []), IoSliceMut::new(&mut [])];
303        let iov = self.bytes_in.get_writeable_iov(&mut iov_buf);
304
305        let read = self.socket.recv(iov, &mut self.fds_in, mode)?;
306        self.bytes_in.move_head(read);
307
308        Ok(())
309    }
310
311    fn send_array(&mut self, array: &[u8]) {
312        let len = array.len() as u32;
313
314        self.bytes_out.write_uint(len);
315        self.bytes_out.write_bytes(array);
316
317        let padding = ((4 - (len % 4)) % 4) as usize;
318        self.bytes_out.write_bytes(&[0, 0, 0][..padding]);
319    }
320
321    fn recv_array(&mut self) -> Vec<u8> {
322        let len = self.bytes_in.read_uint() as usize;
323
324        let mut buf = vec![0; len];
325        self.bytes_in.read_bytes(&mut buf);
326
327        let padding = (4 - (len % 4)) % 4;
328        self.bytes_in.move_tail(padding);
329
330        buf
331    }
332
333    fn recv_string_with_len(&mut self, len: u32) -> Result<CString, RecvMessageError> {
334        let mut buf = vec![0; len as usize];
335        self.bytes_in.read_bytes(&mut buf);
336
337        let padding = (4 - (len % 4)) % 4;
338        self.bytes_in.move_tail(padding as usize);
339
340        CString::from_vec_with_nul(buf).map_err(|_| RecvMessageError::NullInString)
341    }
342
343    fn recv_string(&mut self) -> Result<CString, RecvMessageError> {
344        let len = self.bytes_in.read_uint();
345        if len == 0 {
346            Err(RecvMessageError::UnexpectedNull)
347        } else {
348            self.recv_string_with_len(len)
349        }
350    }
351}