Skip to main content

zlink_core/connection/
write_connection.rs

1//! Contains connection related API.
2
3use core::fmt::Debug;
4
5#[cfg(feature = "std")]
6use alloc::collections::VecDeque;
7use alloc::vec::Vec;
8use serde::Serialize;
9
10use super::{BUFFER_SIZE, Call, Reply, socket::WriteHalf};
11
12#[cfg(feature = "std")]
13use std::os::fd::OwnedFd;
14
/// The write side of a connection.
///
/// The low-level API to send messages.
///
/// # Cancel safety
///
/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
/// documentation.
#[derive(Debug)]
pub struct WriteConnection<Write: WriteHalf> {
    /// The write half of the socket that all outgoing bytes (and, with `std`, FDs) go through.
    pub(super) socket: Write,
    /// Serialization buffer. The bytes `..pos` are enqueued messages, each followed by a `\0`
    /// terminator, that have not been flushed yet; new messages are serialized starting at `pos`.
    pub(super) buffer: Vec<u8>,
    /// Number of bytes at the start of `buffer` that are queued for sending.
    pub(super) pos: usize,
    /// Identifier of this connection, returned by `id()` and included in trace messages.
    id: usize,
    /// File descriptors to send, one entry per enqueued message that carries FDs, in the order
    /// the messages were enqueued.
    #[cfg(feature = "std")]
    pending_fds: VecDeque<MessageFds>,
    // On macOS, SCM_RIGHTS does not properly hold a reference to the underlying file description
    // when the sender closes the FD in the same process before the receiver calls `recvmsg`. The
    // receiver ends up with a stale FD. We work around this by keeping sent FDs alive until the
    // read half confirms it has called `recvmsg` for them via `drain_held_fds`.
    #[cfg(all(feature = "std", target_os = "macos"))]
    pub(super) held_fds: VecDeque<Vec<OwnedFd>>,
}
38
39impl<Write: WriteHalf> WriteConnection<Write> {
40    /// Create a new connection.
41    pub(super) fn new(socket: Write, id: usize) -> Self {
42        Self {
43            socket,
44            id,
45            buffer: alloc::vec![0; BUFFER_SIZE],
46            pos: 0,
47            #[cfg(feature = "std")]
48            pending_fds: VecDeque::new(),
49            #[cfg(all(feature = "std", target_os = "macos"))]
50            held_fds: VecDeque::new(),
51        }
52    }
53
54    /// The unique identifier of the connection.
55    #[inline]
56    pub fn id(&self) -> usize {
57        self.id
58    }
59
60    /// Sends a method call.
61    ///
62    /// The generic `Method` is the type of the method name and its input parameters. This should be
63    /// a type that can serialize itself to a complete method call message, i-e an object containing
64    /// `method` and `parameter` fields. This can be easily achieved using the `serde::Serialize`
65    /// derive:
66    ///
67    /// ```rust
68    /// use serde::{Serialize, Deserialize};
69    /// use serde_prefix_all::prefix_all;
70    ///
71    /// #[prefix_all("org.example.ftl.")]
72    /// #[derive(Debug, Serialize, Deserialize)]
73    /// #[serde(tag = "method", content = "parameters")]
74    /// enum MyMethods<'m> {
75    ///    // The name needs to be the fully-qualified name of the error.
76    ///    Alpha { param1: u32, param2: &'m str},
77    ///    Bravo,
78    ///    Charlie { param1: &'m str },
79    /// }
80    /// ```
81    ///
82    /// The `fds` parameter contains file descriptors to send along with the call.
83    pub async fn send_call<Method>(
84        &mut self,
85        call: &Call<Method>,
86        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
87    ) -> crate::Result<()>
88    where
89        Method: Serialize + Debug,
90    {
91        trace!("connection {}: sending call: {:?}", self.id, call);
92        #[cfg(feature = "std")]
93        {
94            self.write(call, fds).await
95        }
96        #[cfg(not(feature = "std"))]
97        {
98            self.write(call).await
99        }
100    }
101
102    /// Send a reply over the socket.
103    ///
104    /// The generic parameter `Params` is the type of the successful reply. This should be a type
105    /// that can serialize itself as the `parameters` field of the reply.
106    ///
107    /// The `fds` parameter contains file descriptors to send along with the reply.
108    pub async fn send_reply<Params>(
109        &mut self,
110        reply: &Reply<Params>,
111        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
112    ) -> crate::Result<()>
113    where
114        Params: Serialize + Debug,
115    {
116        trace!("connection {}: sending reply: {:?}", self.id, reply);
117        #[cfg(feature = "std")]
118        {
119            self.write(reply, fds).await
120        }
121        #[cfg(not(feature = "std"))]
122        {
123            self.write(reply).await
124        }
125    }
126
127    /// Send an error reply over the socket.
128    ///
129    /// The generic parameter `ReplyError` is the type of the error reply. This should be a type
130    /// that can serialize itself to the whole reply object, containing `error` and `parameter`
131    /// fields. This can be easily achieved using the `serde::Serialize` derive (See the code
132    /// snippet in [`super::ReadConnection::receive_reply`] documentation for an example).
133    ///
134    /// The `fds` parameter contains file descriptors to send along with the error.
135    pub async fn send_error<ReplyError>(
136        &mut self,
137        error: &ReplyError,
138        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
139    ) -> crate::Result<()>
140    where
141        ReplyError: Serialize + Debug,
142    {
143        trace!("connection {}: sending error: {:?}", self.id, error);
144        #[cfg(feature = "std")]
145        {
146            self.write(error, fds).await
147        }
148        #[cfg(not(feature = "std"))]
149        {
150            self.write(error).await
151        }
152    }
153
154    /// Enqueue a call to be sent over the socket.
155    ///
156    /// Similar to [`WriteConnection::send_call`], except that the call is not sent immediately but
157    /// enqueued for later sending. This is useful when you want to send multiple calls in a
158    /// batch.
159    ///
160    /// The `fds` parameter contains file descriptors to send along with the call.
161    pub fn enqueue_call<Method>(
162        &mut self,
163        call: &Call<Method>,
164        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
165    ) -> crate::Result<()>
166    where
167        Method: Serialize + Debug,
168    {
169        trace!("connection {}: enqueuing call: {:?}", self.id, call);
170        #[cfg(feature = "std")]
171        {
172            self.enqueue(call, fds)
173        }
174        #[cfg(not(feature = "std"))]
175        {
176            self.enqueue(call)
177        }
178    }
179
180    /// Send out the enqueued calls.
181    pub async fn flush(&mut self) -> crate::Result<()> {
182        if self.pos == 0 {
183            return Ok(());
184        }
185
186        #[allow(unused_mut)]
187        let mut sent_pos = 0;
188
189        #[cfg(feature = "std")]
190        {
191            // While there are FDs, send one message at a time.
192            while !self.pending_fds.is_empty() {
193                // Get the first FD entry.
194                let pending = self.pending_fds.front().unwrap();
195                let fd_offset = pending.offset;
196                let msg_len = pending.len;
197
198                // If there are bytes before the FD message, send them first without FDs.
199                if sent_pos < fd_offset {
200                    trace!(
201                        "connection {}: flushing {} bytes before FD message",
202                        self.id,
203                        fd_offset - sent_pos
204                    );
205                    self.socket
206                        .write(&self.buffer[sent_pos..fd_offset], &[] as &[OwnedFd])
207                        .await?;
208                }
209
210                // Send this message with its FDs.
211                let msg_end = fd_offset + msg_len;
212                let MessageFds {
213                    fds,
214                    offset: _,
215                    len: _,
216                } = self.pending_fds.pop_front().unwrap();
217                trace!(
218                    "connection {}: flushing {} bytes with {} FDs",
219                    self.id,
220                    msg_len,
221                    fds.len()
222                );
223                self.socket
224                    .write(&self.buffer[fd_offset..msg_end], &fds)
225                    .await?;
226                sent_pos = msg_end;
227
228                // On macOS, keep sent FDs alive until the read half confirms receipt via
229                // `recvmsg`. See the comment on the `held_fds` field for details.
230                #[cfg(target_os = "macos")]
231                self.held_fds.push_back(fds);
232            }
233        }
234
235        // No more FDs, send all remaining bytes at once.
236        if sent_pos < self.pos {
237            trace!(
238                "connection {}: flushing {} bytes",
239                self.id,
240                self.pos - sent_pos
241            );
242            #[cfg(feature = "std")]
243            {
244                self.socket
245                    .write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
246                    .await?;
247            }
248            #[cfg(not(feature = "std"))]
249            {
250                self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
251            }
252        }
253
254        self.pos = 0;
255        Ok(())
256    }
257
258    /// The underlying write half of the socket.
259    pub fn write_half(&self) -> &Write {
260        &self.socket
261    }
262
263    pub(super) async fn write<T>(
264        &mut self,
265        value: &T,
266        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
267    ) -> crate::Result<()>
268    where
269        T: Serialize + ?Sized + Debug,
270    {
271        #[cfg(feature = "std")]
272        {
273            self.enqueue(value, fds)?;
274        }
275        #[cfg(not(feature = "std"))]
276        {
277            self.enqueue(value)?;
278        }
279        self.flush().await
280    }
281
282    pub(super) fn enqueue<T>(
283        &mut self,
284        value: &T,
285        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
286    ) -> crate::Result<()>
287    where
288        T: Serialize + ?Sized + Debug,
289    {
290        #[cfg(feature = "std")]
291        let start_pos = self.pos;
292
293        let len = loop {
294            match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
295                Ok(len) => break len,
296                Err(crate::json_ser::Error::BufferTooSmall) => {
297                    // Buffer too small, grow it and retry
298                    self.grow_buffer()?;
299                }
300                Err(crate::json_ser::Error::KeyMustBeAString) => {
301                    // Actual serialization error
302                    // Convert to serde_json::Error for public API
303                    return Err(crate::Error::Json(serde::ser::Error::custom(
304                        "key must be a string",
305                    )));
306                }
307            }
308        };
309
310        // Add null terminator after this message.
311        if self.pos + len == self.buffer.len() {
312            self.grow_buffer()?;
313        }
314        self.buffer[self.pos + len] = b'\0';
315        self.pos += len + 1;
316
317        // Store FDs with message offset and length.
318        #[cfg(feature = "std")]
319        if !fds.is_empty() {
320            self.pending_fds.push_back(MessageFds {
321                offset: start_pos,
322                len: len + 1, // Include null terminator.
323                fds,
324            });
325        }
326
327        Ok(())
328    }
329
330    fn grow_buffer(&mut self) -> crate::Result<()> {
331        if self.buffer.len() >= super::MAX_BUFFER_SIZE {
332            return Err(crate::Error::BufferOverflow);
333        }
334
335        self.buffer.extend_from_slice(&[0; BUFFER_SIZE]);
336
337        Ok(())
338    }
339}
340
/// Information about file descriptors pending to be sent with a message.
///
/// `offset` and `len` locate the message inside the connection's write buffer, so that the
/// flush logic can write exactly that message's bytes together with `fds`.
#[cfg(feature = "std")]
#[derive(Debug)]
struct MessageFds {
    /// File descriptors to send with this message.
    fds: Vec<OwnedFd>,
    /// Buffer offset where the message starts.
    offset: usize,
    /// Length of the message including the null terminator.
    len: usize,
}