Skip to main content

zlink_core/connection/
read_connection.rs

1//! Contains connection related API.
2
3use core::{fmt::Debug, str::from_utf8_unchecked};
4
5use crate::{Result, varlink_service};
6
7use super::{
8    BUFFER_SIZE, Call, MAX_BUFFER_SIZE,
9    reply::{self, Reply},
10    socket::ReadHalf,
11};
12#[cfg(feature = "std")]
13use alloc::collections::VecDeque;
14use alloc::vec::Vec;
15use serde::Deserialize;
16use serde_json::Deserializer;
17
18#[cfg(feature = "std")]
19use std::os::fd::OwnedFd;
20
// Return type of the `receive_*` methods.
//
// With `std`, the received value is paired with the file descriptors handed out alongside it
// (an empty `Vec` when none are pending). Without `std` there is no FD support, so it is just
// the value itself.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<OwnedFd>);
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
26
27/// A connection that can only be used for reading.
28///
29/// # Cancel safety
30///
31/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
32/// documentation.
33#[derive(Debug)]
34pub struct ReadConnection<Read: ReadHalf> {
35    socket: Read,
36    read_pos: usize,
37    msg_pos: usize,
38    buffer: Vec<u8>,
39    id: usize,
40    #[cfg(feature = "std")]
41    pending_fds: VecDeque<Vec<OwnedFd>>,
42    // Number of `recvmsg` calls that returned FDs. Used by `Connection` to drain
43    // `WriteConnection::held_fds` on macOS. See that field's comment for details.
44    #[cfg(all(feature = "std", target_os = "macos"))]
45    pub(super) fd_recvs: usize,
46}
47
48impl<Read: ReadHalf> ReadConnection<Read> {
49    /// Create a new connection.
50    pub(super) fn new(socket: Read, id: usize) -> Self {
51        Self {
52            socket,
53            read_pos: 0,
54            msg_pos: 0,
55            id,
56            buffer: alloc::vec![0; BUFFER_SIZE],
57            #[cfg(feature = "std")]
58            pending_fds: VecDeque::new(),
59            #[cfg(all(feature = "std", target_os = "macos"))]
60            fd_recvs: 0,
61        }
62    }
63
64    /// The unique identifier of the connection.
65    #[inline]
66    pub fn id(&self) -> usize {
67        self.id
68    }
69
70    /// Receives a method call reply.
71    ///
72    /// The generic parameters needs some explanation:
73    ///
74    /// * `ReplyParams` is the type of the successful reply. This should be a type that can
75    ///   deserialize itself from the `parameters` field of the reply.
76    /// * `ReplyError` is the type of the error reply. This should be a type that can deserialize
77    ///   itself from the whole reply object itself and must fail when there is no `error` field in
78    ///   the object. This can be easily achieved using the `zlink::ReplyError` derive:
79    ///
80    /// ```rust
81    /// use zlink_core::ReplyError;
82    ///
83    /// #[derive(Debug, ReplyError)]
84    /// #[zlink(
85    ///     interface = "org.example.ftl",
86    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
87    ///     crate = "zlink_core",
88    /// )]
89    /// enum MyError {
90    ///     Alpha { param1: u32, param2: String },
91    ///     Bravo,
92    ///     Charlie { param1: String },
93    /// }
94    /// ```
95    ///
96    /// Returns the reply and any file descriptors received (std only).
97    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
98        &'r mut self,
99    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
100    where
101        ReplyParams: Deserialize<'r> + Debug,
102        ReplyError: Deserialize<'r> + Debug,
103    {
104        #[derive(Debug, Deserialize)]
105        #[serde(untagged)]
106        enum ReplyMsg<'m, ReplyParams, ReplyError> {
107            #[serde(borrow)]
108            Varlink(varlink_service::Error<'m>),
109            Error(ReplyError),
110            Reply(Reply<ReplyParams>),
111        }
112
113        let recv_result = self
114            .read_message::<ReplyMsg<'_, ReplyParams, ReplyError>>()
115            .await?;
116
117        #[cfg(feature = "std")]
118        let (msg, fds) = recv_result;
119        #[cfg(not(feature = "std"))]
120        let msg = recv_result;
121
122        let result = match msg {
123            // Varlink service interface error need to be returned as the top-level error.
124            ReplyMsg::Varlink(e) => Err(crate::Error::VarlinkService(e.into())),
125            ReplyMsg::Error(e) => Ok(Err(e)),
126            ReplyMsg::Reply(reply) => Ok(Ok(reply)),
127        };
128
129        #[cfg(feature = "std")]
130        return result.map(|r| (r, fds));
131        #[cfg(not(feature = "std"))]
132        return result;
133    }
134
135    /// Receive a method call over the socket.
136    ///
137    /// The generic `Method` is the type of the method name and its input parameters. This should be
138    /// a type that can deserialize itself from a complete method call message, i-e an object
139    /// containing `method` and `parameter` fields. This can be easily achieved using the
140    /// `serde::Deserialize` derive (See the code snippet in [`super::WriteConnection::send_call`]
141    /// documentation for an example).
142    ///
143    /// Returns the call and any file descriptors received (std only).
144    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
145    where
146        Method: Deserialize<'m> + Debug,
147    {
148        self.read_message::<Call<Method>>().await
149    }
150
151    // Reads at least one full message from the socket and return a single message bytes.
152    async fn read_message<'m, M>(&'m mut self) -> Result<RecvResult<M>>
153    where
154        M: Deserialize<'m> + Debug,
155    {
156        self.read_from_socket().await?;
157
158        let mut stream = Deserializer::from_slice(&self.buffer[self.msg_pos..]).into_iter::<M>();
159        let msg = stream.next();
160        let null_index = self.msg_pos + stream.byte_offset();
161        let buffer = &self.buffer[self.msg_pos..null_index];
162        if self.buffer[null_index + 1] == b'\0' {
163            // This means we're reading the last message and can now reset the indices.
164            self.read_pos = 0;
165            self.msg_pos = 0;
166        } else {
167            self.msg_pos = null_index + 1;
168        }
169
170        match msg {
171            Some(Ok(msg)) => {
172                // SAFETY: Since the parsing from JSON already succeeded, we can be sure that the
173                // buffer contains a valid UTF-8 string.
174                trace!("connection {}: received a message: {}", self.id, unsafe {
175                    from_utf8_unchecked(buffer)
176                });
177
178                #[cfg(feature = "std")]
179                {
180                    let fds = self.pending_fds.pop_front().unwrap_or_default();
181                    Ok((msg, fds))
182                }
183                #[cfg(not(feature = "std"))]
184                Ok(msg)
185            }
186            Some(Err(e)) => Err(e.into()),
187            None => Err(crate::Error::UnexpectedEof),
188        }
189    }
190
191    // Reads at least one full message from the socket.
192    async fn read_from_socket(&mut self) -> Result<()> {
193        if self.msg_pos > 0 {
194            // This means we already have at least one message in the buffer so no need to read.
195            return Ok(());
196        }
197
198        loop {
199            #[cfg(feature = "std")]
200            let (bytes_read, fds) = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
201            #[cfg(not(feature = "std"))]
202            let bytes_read = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
203
204            if bytes_read == 0 {
205                return Err(crate::Error::UnexpectedEof);
206            }
207            self.read_pos += bytes_read;
208            #[cfg(feature = "std")]
209            if !fds.is_empty() {
210                self.pending_fds.push_back(fds);
211                // Track receipt so `Connection` can drain `WriteConnection::held_fds`.
212                #[cfg(target_os = "macos")]
213                {
214                    self.fd_recvs += 1;
215                }
216            }
217
218            if self.read_pos == self.buffer.len() {
219                if self.read_pos >= MAX_BUFFER_SIZE {
220                    return Err(crate::Error::BufferOverflow);
221                }
222
223                self.buffer.extend(core::iter::repeat_n(0, BUFFER_SIZE));
224            }
225
226            // This marks end of all messages. After this loop is finished, we'll have 2 consecutive
227            // null bytes at the end. This is then used by the callers to determine that they've
228            // read all messages and can now reset the `read_pos`.
229            self.buffer[self.read_pos] = b'\0';
230
231            if self.buffer[self.read_pos - 1] == b'\0' {
232                // One or more full messages were read.
233                break;
234            }
235        }
236
237        Ok(())
238    }
239
240    /// The underlying read half of the socket.
241    pub fn read_half(&self) -> &Read {
242        &self.socket
243    }
244}