// zlink_core/connection/read_connection.rs

1//! Contains connection related API.
2
3use core::{fmt::Debug, str::from_utf8_unchecked};
4
5use crate::{varlink_service, Result};
6
7use super::{
8    reply::{self, Reply},
9    socket::ReadHalf,
10    Call, BUFFER_SIZE, MAX_BUFFER_SIZE,
11};
12#[cfg(feature = "std")]
13use alloc::collections::VecDeque;
14use alloc::vec::Vec;
15use serde::Deserialize;
16use serde_json::Deserializer;
17
18#[cfg(feature = "std")]
19use std::os::fd::OwnedFd;
20
/// Value produced by the `receive_*` methods when the `std` feature is enabled: the decoded
/// message together with the file descriptors handed out alongside it.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<OwnedFd>);

/// Value produced by the `receive_*` methods without the `std` feature: just the decoded message,
/// since no file descriptors are collected in that configuration.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
26
27/// A connection that can only be used for reading.
28///
29/// # Cancel safety
30///
31/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
32/// documentation.
33#[derive(Debug)]
34pub struct ReadConnection<Read: ReadHalf> {
35    socket: Read,
36    read_pos: usize,
37    msg_pos: usize,
38    buffer: Vec<u8>,
39    id: usize,
40    #[cfg(feature = "std")]
41    pending_fds: VecDeque<Vec<OwnedFd>>,
42}
43
44impl<Read: ReadHalf> ReadConnection<Read> {
45    /// Create a new connection.
46    pub(super) fn new(socket: Read, id: usize) -> Self {
47        Self {
48            socket,
49            read_pos: 0,
50            msg_pos: 0,
51            id,
52            buffer: alloc::vec![0; BUFFER_SIZE],
53            #[cfg(feature = "std")]
54            pending_fds: VecDeque::new(),
55        }
56    }
57
58    /// The unique identifier of the connection.
59    #[inline]
60    pub fn id(&self) -> usize {
61        self.id
62    }
63
64    /// Receives a method call reply.
65    ///
66    /// The generic parameters needs some explanation:
67    ///
68    /// * `ReplyParams` is the type of the successful reply. This should be a type that can
69    ///   deserialize itself from the `parameters` field of the reply.
70    /// * `ReplyError` is the type of the error reply. This should be a type that can deserialize
71    ///   itself from the whole reply object itself and must fail when there is no `error` field in
72    ///   the object. This can be easily achieved using the `zlink::ReplyError` derive:
73    ///
74    /// ```rust
75    /// use zlink_core::ReplyError;
76    ///
77    /// #[derive(Debug, ReplyError)]
78    /// #[zlink(
79    ///     interface = "org.example.ftl",
80    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
81    ///     crate = "zlink_core",
82    /// )]
83    /// enum MyError {
84    ///     Alpha { param1: u32, param2: String },
85    ///     Bravo,
86    ///     Charlie { param1: String },
87    /// }
88    /// ```
89    ///
90    /// Returns the reply and any file descriptors received (std only).
91    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
92        &'r mut self,
93    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
94    where
95        ReplyParams: Deserialize<'r> + Debug,
96        ReplyError: Deserialize<'r> + Debug,
97    {
98        #[derive(Debug, Deserialize)]
99        #[serde(untagged)]
100        enum ReplyMsg<'m, ReplyParams, ReplyError> {
101            #[serde(borrow)]
102            Varlink(varlink_service::Error<'m>),
103            Error(ReplyError),
104            Reply(Reply<ReplyParams>),
105        }
106
107        let recv_result = self
108            .read_message::<ReplyMsg<'_, ReplyParams, ReplyError>>()
109            .await?;
110
111        #[cfg(feature = "std")]
112        let (msg, fds) = recv_result;
113        #[cfg(not(feature = "std"))]
114        let msg = recv_result;
115
116        let result = match msg {
117            // Varlink service interface error need to be returned as the top-level error.
118            ReplyMsg::Varlink(e) => Err(crate::Error::VarlinkService(e.into_owned())),
119            ReplyMsg::Error(e) => Ok(Err(e)),
120            ReplyMsg::Reply(reply) => Ok(Ok(reply)),
121        };
122
123        #[cfg(feature = "std")]
124        return result.map(|r| (r, fds));
125        #[cfg(not(feature = "std"))]
126        return result;
127    }
128
129    /// Receive a method call over the socket.
130    ///
131    /// The generic `Method` is the type of the method name and its input parameters. This should be
132    /// a type that can deserialize itself from a complete method call message, i-e an object
133    /// containing `method` and `parameter` fields. This can be easily achieved using the
134    /// `serde::Deserialize` derive (See the code snippet in [`super::WriteConnection::send_call`]
135    /// documentation for an example).
136    ///
137    /// Returns the call and any file descriptors received (std only).
138    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
139    where
140        Method: Deserialize<'m> + Debug,
141    {
142        self.read_message::<Call<Method>>().await
143    }
144
145    // Reads at least one full message from the socket and return a single message bytes.
146    async fn read_message<'m, M>(&'m mut self) -> Result<RecvResult<M>>
147    where
148        M: Deserialize<'m> + Debug,
149    {
150        self.read_from_socket().await?;
151
152        let mut stream = Deserializer::from_slice(&self.buffer[self.msg_pos..]).into_iter::<M>();
153        let msg = stream.next();
154        let null_index = self.msg_pos + stream.byte_offset();
155        let buffer = &self.buffer[self.msg_pos..null_index];
156        if self.buffer[null_index + 1] == b'\0' {
157            // This means we're reading the last message and can now reset the indices.
158            self.read_pos = 0;
159            self.msg_pos = 0;
160        } else {
161            self.msg_pos = null_index + 1;
162        }
163
164        match msg {
165            Some(Ok(msg)) => {
166                // SAFETY: Since the parsing from JSON already succeeded, we can be sure that the
167                // buffer contains a valid UTF-8 string.
168                trace!("connection {}: received a message: {}", self.id, unsafe {
169                    from_utf8_unchecked(buffer)
170                });
171
172                #[cfg(feature = "std")]
173                {
174                    let fds = self.pending_fds.pop_front().unwrap_or_default();
175                    Ok((msg, fds))
176                }
177                #[cfg(not(feature = "std"))]
178                Ok(msg)
179            }
180            Some(Err(e)) => Err(e.into()),
181            None => Err(crate::Error::UnexpectedEof),
182        }
183    }
184
185    // Reads at least one full message from the socket.
186    async fn read_from_socket(&mut self) -> Result<()> {
187        if self.msg_pos > 0 {
188            // This means we already have at least one message in the buffer so no need to read.
189            return Ok(());
190        }
191
192        loop {
193            #[cfg(feature = "std")]
194            let (bytes_read, fds) = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
195            #[cfg(not(feature = "std"))]
196            let bytes_read = self.socket.read(&mut self.buffer[self.read_pos..]).await?;
197
198            if bytes_read == 0 {
199                return Err(crate::Error::UnexpectedEof);
200            }
201            self.read_pos += bytes_read;
202            #[cfg(feature = "std")]
203            if !fds.is_empty() {
204                self.pending_fds.push_back(fds);
205            }
206
207            if self.read_pos == self.buffer.len() {
208                if self.read_pos >= MAX_BUFFER_SIZE {
209                    return Err(crate::Error::BufferOverflow);
210                }
211
212                self.buffer.extend(core::iter::repeat_n(0, BUFFER_SIZE));
213            }
214
215            // This marks end of all messages. After this loop is finished, we'll have 2 consecutive
216            // null bytes at the end. This is then used by the callers to determine that they've
217            // read all messages and can now reset the `read_pos`.
218            self.buffer[self.read_pos] = b'\0';
219
220            if self.buffer[self.read_pos - 1] == b'\0' {
221                // One or more full messages were read.
222                break;
223            }
224        }
225
226        Ok(())
227    }
228
229    /// The underlying read half of the socket.
230    pub fn read_half(&self) -> &Read {
231        &self.socket
232    }
233}