zbus/connection/socket/
mod.rs

1#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9#[cfg(unix)]
10pub(crate) mod command;
11#[cfg(unix)]
12pub(crate) use command::Command;
13mod tcp;
14mod unix;
15mod vsock;
16
17#[cfg(not(feature = "tokio"))]
18use async_io::Async;
19#[cfg(not(feature = "tokio"))]
20use std::sync::Arc;
21use std::{io, mem};
22use tracing::trace;
23
24use crate::{
25    Message,
26    conn::AuthMechanism,
27    fdo::ConnectionCredentials,
28    message::{
29        PrimaryHeader,
30        header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
31    },
32    padding_for_8_bytes,
33};
34#[cfg(unix)]
35use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
36use zvariant::{
37    Endian,
38    serialized::{self, Context},
39};
40
/// Outcome of [`ReadHalf::recvmsg`]: the number of bytes read together with any file
/// descriptors that arrived alongside them.
#[cfg(unix)]
type RecvmsgResult = Result<(usize, Vec<OwnedFd>), io::Error>;

/// Outcome of [`ReadHalf::recvmsg`]: the number of bytes read (no file descriptors are reported
/// on non-Unix targets).
#[cfg(not(unix))]
type RecvmsgResult = Result<usize, io::Error>;
46
47/// Trait representing some transport layer over which the DBus protocol can be used.
48///
49/// In order to allow simultaneous reading and writing, this trait requires you to split the socket
50/// into a read half and a write half. The reader and writer halves can be any types that implement
51/// [`ReadHalf`] and [`WriteHalf`] respectively.
52///
53/// The crate provides implementations for `async_io` and `tokio`'s `UnixStream` wrappers if you
54/// enable the corresponding crate features (`async_io` is enabled by default).
55///
56/// You can implement it manually to integrate with other runtimes or other dbus transports.  Feel
57/// free to submit pull requests to add support for more runtimes to zbus itself so rust's orphan
58/// rules don't force the use of a wrapper struct (and to avoid duplicating the work across many
59/// projects).
60pub trait Socket {
61    type ReadHalf: ReadHalf;
62    type WriteHalf: WriteHalf;
63
64    /// Split the socket into a read half and a write half.
65    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
66    where
67        Self: Sized;
68}
69
70/// The read half of a socket.
71///
72/// See [`Socket`] for more details.
73#[async_trait::async_trait]
74pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
75    /// Receive a message on the socket.
76    ///
77    /// This is the higher-level method to receive a full D-Bus message.
78    ///
79    /// The default implementation uses `recvmsg` to receive the message. Implementers should
80    /// override either this or `recvmsg`. Note that if you override this method, zbus will not be
81    /// able perform an authentication handshake and hence will skip the handshake. Therefore your
82    /// implementation will only be useful for pre-authenticated connections or connections that do
83    /// not require authentication.
84    ///
85    /// # Parameters
86    ///
87    /// - `seq`: The sequence number of the message. The returned message should have this sequence.
88    /// - `already_received_bytes`: Sometimes, zbus already received some bytes from the socket
89    ///   belonging to the first message(s) (as part of the connection handshake process). This is
90    ///   the buffer containing those bytes (if any). If you're implementing this method, most
91    ///   likely you can safely ignore this parameter.
92    /// - `already_received_fds`: Same goes for file descriptors belonging to first messages.
93    async fn receive_message(
94        &mut self,
95        seq: u64,
96        already_received_bytes: &mut Vec<u8>,
97        #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
98    ) -> crate::Result<Message> {
99        #[cfg(unix)]
100        let mut fds = vec![];
101        let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
102            let mut bytes = vec![];
103            if !already_received_bytes.is_empty() {
104                mem::swap(already_received_bytes, &mut bytes);
105            }
106            let mut pos = bytes.len();
107            bytes.resize(MIN_MESSAGE_SIZE, 0);
108            // We don't have enough data to make a proper message header yet.
109            // Some partial read may be in raw_in_buffer, so we try to complete it
110            // until we have MIN_MESSAGE_SIZE bytes
111            //
112            // Given that MIN_MESSAGE_SIZE is 16, this codepath is actually extremely unlikely
113            // to be taken more than once
114            while pos < MIN_MESSAGE_SIZE {
115                let res = self.recvmsg(&mut bytes[pos..]).await?;
116                let len = {
117                    #[cfg(unix)]
118                    {
119                        fds.extend(res.1);
120                        res.0
121                    }
122                    #[cfg(not(unix))]
123                    {
124                        res
125                    }
126                };
127                pos += len;
128                if len == 0 {
129                    return Err(std::io::Error::new(
130                        std::io::ErrorKind::UnexpectedEof,
131                        "failed to receive message",
132                    )
133                    .into());
134                }
135            }
136
137            bytes
138        } else {
139            already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
140        };
141
142        let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
143        let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
144        let body_padding = padding_for_8_bytes(header_len);
145        let body_len = primary_header.body_len() as usize;
146        let total_len = header_len + body_padding + body_len;
147        if total_len > MAX_MESSAGE_SIZE {
148            return Err(crate::Error::ExcessData);
149        }
150
151        // By this point we have a full primary header, so we know the exact length of the complete
152        // message.
153        if !already_received_bytes.is_empty() {
154            // still have some bytes buffered.
155            let pending = total_len - bytes.len();
156            let to_take = std::cmp::min(pending, already_received_bytes.len());
157            bytes.extend(already_received_bytes.drain(..to_take));
158        }
159        let mut pos = bytes.len();
160        bytes.resize(total_len, 0);
161
162        // Read the rest, if any
163        while pos < total_len {
164            let res = self.recvmsg(&mut bytes[pos..]).await?;
165            let read = {
166                #[cfg(unix)]
167                {
168                    fds.extend(res.1);
169                    res.0
170                }
171                #[cfg(not(unix))]
172                {
173                    res
174                }
175            };
176            pos += read;
177            if read == 0 {
178                return Err(crate::Error::InputOutput(
179                    std::io::Error::new(
180                        std::io::ErrorKind::UnexpectedEof,
181                        "failed to receive message",
182                    )
183                    .into(),
184                ));
185            }
186        }
187
188        // If we reach here, the message is complete; return it
189        let endian = Endian::from(primary_header.endian_sig());
190
191        #[cfg(unix)]
192        if !already_received_fds.is_empty() {
193            use crate::message::header::PRIMARY_HEADER_SIZE;
194
195            let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
196            let encoded_fields =
197                serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
198            let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
199            let num_required_fds = match fields.unix_fds {
200                Some(num_fds) => num_fds as usize,
201                _ => 0,
202            };
203            let num_pending = num_required_fds
204                .checked_sub(fds.len())
205                .ok_or_else(|| crate::Error::ExcessData)?;
206            // If we had previously received FDs, `num_pending` has to be > 0
207            if num_pending == 0 {
208                return Err(crate::Error::MissingParameter("Missing file descriptors"));
209            }
210            // All previously received FDs must go first in the list.
211            let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
212            mem::swap(&mut already_received, &mut fds);
213            fds.extend(already_received);
214        }
215
216        let ctxt = Context::new_dbus(endian, 0);
217        #[cfg(unix)]
218        let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
219        #[cfg(not(unix))]
220        let bytes = serialized::Data::new(bytes, ctxt);
221        Message::from_raw_parts(bytes, seq)
222    }
223
224    /// Attempt to receive bytes from the socket.
225    ///
226    /// On success, returns the number of bytes read as well as a `Vec` containing
227    /// any associated file descriptors.
228    ///
229    /// The default implementation simply panics. Implementers must override either `read_message`
230    /// or this method.
231    async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
232        unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
233    }
234
235    /// Return whether passing file descriptors is supported.
236    ///
237    /// Default implementation returns `false`.
238    fn can_pass_unix_fd(&self) -> bool {
239        false
240    }
241
242    /// The peer credentials.
243    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
244        Ok(ConnectionCredentials::default())
245    }
246
247    /// The authentication mechanism to use for this socket on the target OS.
248    ///
249    /// Default is `AuthMechanism::External`.
250    fn auth_mechanism(&self) -> AuthMechanism {
251        AuthMechanism::External
252    }
253}
254
255/// The write half of a socket.
256///
257/// See [`Socket`] for more details.
258#[async_trait::async_trait]
259pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
260    /// Send a message on the socket.
261    ///
262    /// This is the higher-level method to send a full D-Bus message.
263    ///
264    /// The default implementation uses `sendmsg` to send the message. Implementers should override
265    /// either this or `sendmsg`.
266    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
267        let data = msg.data();
268        let serial = msg.primary_header().serial_num();
269
270        trace!("Sending message: {:?}", msg);
271        let mut pos = 0;
272        while pos < data.len() {
273            #[cfg(unix)]
274            let fds = if pos == 0 {
275                data.fds().iter().map(|f| f.as_fd()).collect()
276            } else {
277                vec![]
278            };
279            pos += self
280                .sendmsg(
281                    &data[pos..],
282                    #[cfg(unix)]
283                    &fds,
284                )
285                .await?;
286        }
287        trace!("Sent message with serial: {}", serial);
288
289        Ok(())
290    }
291
292    /// Attempt to send a message on the socket
293    ///
294    /// On success, return the number of bytes written. There may be a partial write, in
295    /// which case the caller is responsible for sending the remaining data by calling this
296    /// method again until everything is written or it returns an error of kind `WouldBlock`.
297    ///
298    /// If at least one byte has been written, then all the provided file descriptors will
299    /// have been sent as well, and should not be provided again in subsequent calls.
300    ///
301    /// If the underlying transport does not support transmitting file descriptors, this
302    /// will return `Err(ErrorKind::InvalidInput)`.
303    ///
304    /// The default implementation simply panics. Implementers must override either `send_message`
305    /// or this method.
306    async fn sendmsg(
307        &mut self,
308        _buffer: &[u8],
309        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
310    ) -> io::Result<usize> {
311        unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
312    }
313
314    /// The dbus daemon on `freebsd` and `dragonfly` currently requires sending the zero byte
315    /// as a separate message with SCM_CREDS, as part of the `EXTERNAL` authentication on unix
316    /// sockets. This method is used by the authentication machinery in zbus to send this
317    /// zero byte. Socket implementations based on unix sockets should implement this method.
318    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
319    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
320        Ok(None)
321    }
322
323    /// Close the socket.
324    ///
325    /// After this call, it is valid for all reading and writing operations to fail.
326    async fn close(&mut self) -> io::Result<()>;
327
328    /// Whether passing file descriptors is supported.
329    ///
330    /// Default implementation returns `false`.
331    fn can_pass_unix_fd(&self) -> bool {
332        false
333    }
334
335    /// The peer credentials.
336    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
337        Ok(ConnectionCredentials::default())
338    }
339}
340
341#[async_trait::async_trait]
342impl ReadHalf for Box<dyn ReadHalf> {
343    fn can_pass_unix_fd(&self) -> bool {
344        (**self).can_pass_unix_fd()
345    }
346
347    async fn receive_message(
348        &mut self,
349        seq: u64,
350        already_received_bytes: &mut Vec<u8>,
351        #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
352    ) -> crate::Result<Message> {
353        (**self)
354            .receive_message(
355                seq,
356                already_received_bytes,
357                #[cfg(unix)]
358                already_received_fds,
359            )
360            .await
361    }
362
363    async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
364        (**self).recvmsg(buf).await
365    }
366
367    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
368        (**self).peer_credentials().await
369    }
370
371    fn auth_mechanism(&self) -> AuthMechanism {
372        (**self).auth_mechanism()
373    }
374}
375
376#[async_trait::async_trait]
377impl WriteHalf for Box<dyn WriteHalf> {
378    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
379        (**self).send_message(msg).await
380    }
381
382    async fn sendmsg(
383        &mut self,
384        buffer: &[u8],
385        #[cfg(unix)] fds: &[BorrowedFd<'_>],
386    ) -> io::Result<usize> {
387        (**self)
388            .sendmsg(
389                buffer,
390                #[cfg(unix)]
391                fds,
392            )
393            .await
394    }
395
396    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
397    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
398        (**self).send_zero_byte().await
399    }
400
401    async fn close(&mut self) -> io::Result<()> {
402        (**self).close().await
403    }
404
405    fn can_pass_unix_fd(&self) -> bool {
406        (**self).can_pass_unix_fd()
407    }
408
409    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
410        (**self).peer_credentials().await
411    }
412}
413
414#[cfg(not(feature = "tokio"))]
415impl<T> Socket for Async<T>
416where
417    T: std::fmt::Debug + Send + Sync,
418    Arc<Async<T>>: ReadHalf + WriteHalf,
419{
420    type ReadHalf = Arc<Async<T>>;
421    type WriteHalf = Arc<Async<T>>;
422
423    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
424        let arc = Arc::new(self);
425
426        Split {
427            read: arc.clone(),
428            write: arc,
429        }
430    }
431}