zlink_core/connection/
write_connection.rs

//! Contains the API for writing messages to a connection.
2
3use core::fmt::Debug;
4
5#[cfg(feature = "std")]
6use alloc::collections::VecDeque;
7use alloc::vec::Vec;
8use serde::Serialize;
9
10use super::{socket::WriteHalf, Call, Reply, BUFFER_SIZE};
11
12#[cfg(feature = "std")]
13use std::os::fd::OwnedFd;
14
15/// A connection.
16///
17/// The low-level API to send messages.
18///
19/// # Cancel safety
20///
21/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
22/// documentation.
23#[derive(Debug)]
24pub struct WriteConnection<Write: WriteHalf> {
25    pub(super) socket: Write,
26    pub(super) buffer: Vec<u8>,
27    pub(super) pos: usize,
28    id: usize,
29    #[cfg(feature = "std")]
30    pending_fds: VecDeque<MessageFds>,
31}
32
33impl<Write: WriteHalf> WriteConnection<Write> {
34    /// Create a new connection.
35    pub(super) fn new(socket: Write, id: usize) -> Self {
36        Self {
37            socket,
38            id,
39            buffer: alloc::vec![0; BUFFER_SIZE],
40            pos: 0,
41            #[cfg(feature = "std")]
42            pending_fds: VecDeque::new(),
43        }
44    }
45
46    /// The unique identifier of the connection.
47    #[inline]
48    pub fn id(&self) -> usize {
49        self.id
50    }
51
52    /// Sends a method call.
53    ///
54    /// The generic `Method` is the type of the method name and its input parameters. This should be
55    /// a type that can serialize itself to a complete method call message, i-e an object containing
56    /// `method` and `parameter` fields. This can be easily achieved using the `serde::Serialize`
57    /// derive:
58    ///
59    /// ```rust
60    /// use serde::{Serialize, Deserialize};
61    /// use serde_prefix_all::prefix_all;
62    ///
63    /// #[prefix_all("org.example.ftl.")]
64    /// #[derive(Debug, Serialize, Deserialize)]
65    /// #[serde(tag = "method", content = "parameters")]
66    /// enum MyMethods<'m> {
67    ///    // The name needs to be the fully-qualified name of the error.
68    ///    Alpha { param1: u32, param2: &'m str},
69    ///    Bravo,
70    ///    Charlie { param1: &'m str },
71    /// }
72    /// ```
73    ///
74    /// The `fds` parameter contains file descriptors to send along with the call.
75    pub async fn send_call<Method>(
76        &mut self,
77        call: &Call<Method>,
78        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
79    ) -> crate::Result<()>
80    where
81        Method: Serialize + Debug,
82    {
83        trace!("connection {}: sending call: {:?}", self.id, call);
84        #[cfg(feature = "std")]
85        {
86            self.write(call, fds).await
87        }
88        #[cfg(not(feature = "std"))]
89        {
90            self.write(call).await
91        }
92    }
93
94    /// Send a reply over the socket.
95    ///
96    /// The generic parameter `Params` is the type of the successful reply. This should be a type
97    /// that can serialize itself as the `parameters` field of the reply.
98    ///
99    /// The `fds` parameter contains file descriptors to send along with the reply.
100    pub async fn send_reply<Params>(
101        &mut self,
102        reply: &Reply<Params>,
103        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
104    ) -> crate::Result<()>
105    where
106        Params: Serialize + Debug,
107    {
108        trace!("connection {}: sending reply: {:?}", self.id, reply);
109        #[cfg(feature = "std")]
110        {
111            self.write(reply, fds).await
112        }
113        #[cfg(not(feature = "std"))]
114        {
115            self.write(reply).await
116        }
117    }
118
119    /// Send an error reply over the socket.
120    ///
121    /// The generic parameter `ReplyError` is the type of the error reply. This should be a type
122    /// that can serialize itself to the whole reply object, containing `error` and `parameter`
123    /// fields. This can be easily achieved using the `serde::Serialize` derive (See the code
124    /// snippet in [`super::ReadConnection::receive_reply`] documentation for an example).
125    ///
126    /// The `fds` parameter contains file descriptors to send along with the error.
127    pub async fn send_error<ReplyError>(
128        &mut self,
129        error: &ReplyError,
130        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
131    ) -> crate::Result<()>
132    where
133        ReplyError: Serialize + Debug,
134    {
135        trace!("connection {}: sending error: {:?}", self.id, error);
136        #[cfg(feature = "std")]
137        {
138            self.write(error, fds).await
139        }
140        #[cfg(not(feature = "std"))]
141        {
142            self.write(error).await
143        }
144    }
145
146    /// Enqueue a call to be sent over the socket.
147    ///
148    /// Similar to [`WriteConnection::send_call`], except that the call is not sent immediately but
149    /// enqueued for later sending. This is useful when you want to send multiple calls in a
150    /// batch.
151    ///
152    /// The `fds` parameter contains file descriptors to send along with the call.
153    pub fn enqueue_call<Method>(
154        &mut self,
155        call: &Call<Method>,
156        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
157    ) -> crate::Result<()>
158    where
159        Method: Serialize + Debug,
160    {
161        trace!("connection {}: enqueuing call: {:?}", self.id, call);
162        #[cfg(feature = "std")]
163        {
164            self.enqueue(call, fds)
165        }
166        #[cfg(not(feature = "std"))]
167        {
168            self.enqueue(call)
169        }
170    }
171
172    /// Send out the enqueued calls.
173    pub async fn flush(&mut self) -> crate::Result<()> {
174        if self.pos == 0 {
175            return Ok(());
176        }
177
178        #[allow(unused_mut)]
179        let mut sent_pos = 0;
180
181        #[cfg(feature = "std")]
182        {
183            // While there are FDs, send one message at a time.
184            while !self.pending_fds.is_empty() {
185                // Get the first FD entry.
186                let pending = self.pending_fds.front().unwrap();
187                let fd_offset = pending.offset;
188                let msg_len = pending.len;
189
190                // If there are bytes before the FD message, send them first without FDs.
191                if sent_pos < fd_offset {
192                    trace!(
193                        "connection {}: flushing {} bytes before FD message",
194                        self.id,
195                        fd_offset - sent_pos
196                    );
197                    self.socket
198                        .write(&self.buffer[sent_pos..fd_offset], &[] as &[OwnedFd])
199                        .await?;
200                }
201
202                // Send this message with its FDs.
203                let msg_end = fd_offset + msg_len;
204                let pending = self.pending_fds.pop_front().unwrap();
205                let fds = &pending.fds;
206                trace!(
207                    "connection {}: flushing {} bytes with {} FDs",
208                    self.id,
209                    msg_len,
210                    fds.len()
211                );
212                self.socket
213                    .write(&self.buffer[fd_offset..msg_end], fds)
214                    .await?;
215                sent_pos = msg_end;
216            }
217        }
218
219        // No more FDs, send all remaining bytes at once.
220        if sent_pos < self.pos {
221            trace!(
222                "connection {}: flushing {} bytes",
223                self.id,
224                self.pos - sent_pos
225            );
226            #[cfg(feature = "std")]
227            {
228                self.socket
229                    .write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
230                    .await?;
231            }
232            #[cfg(not(feature = "std"))]
233            {
234                self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
235            }
236        }
237
238        self.pos = 0;
239        Ok(())
240    }
241
242    /// The underlying write half of the socket.
243    pub fn write_half(&self) -> &Write {
244        &self.socket
245    }
246
247    pub(super) async fn write<T>(
248        &mut self,
249        value: &T,
250        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
251    ) -> crate::Result<()>
252    where
253        T: Serialize + ?Sized + Debug,
254    {
255        #[cfg(feature = "std")]
256        {
257            self.enqueue(value, fds)?;
258        }
259        #[cfg(not(feature = "std"))]
260        {
261            self.enqueue(value)?;
262        }
263        self.flush().await
264    }
265
266    pub(super) fn enqueue<T>(
267        &mut self,
268        value: &T,
269        #[cfg(feature = "std")] fds: Vec<OwnedFd>,
270    ) -> crate::Result<()>
271    where
272        T: Serialize + ?Sized + Debug,
273    {
274        #[cfg(feature = "std")]
275        let start_pos = self.pos;
276
277        let len = loop {
278            match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
279                Ok(len) => break len,
280                Err(crate::json_ser::Error::BufferTooSmall) => {
281                    // Buffer too small, grow it and retry
282                    self.grow_buffer()?;
283                }
284                Err(crate::json_ser::Error::KeyMustBeAString) => {
285                    // Actual serialization error
286                    // Convert to serde_json::Error for public API
287                    return Err(crate::Error::Json(serde::ser::Error::custom(
288                        "key must be a string",
289                    )));
290                }
291            }
292        };
293
294        // Add null terminator after this message.
295        if self.pos + len == self.buffer.len() {
296            self.grow_buffer()?;
297        }
298        self.buffer[self.pos + len] = b'\0';
299        self.pos += len + 1;
300
301        // Store FDs with message offset and length.
302        #[cfg(feature = "std")]
303        if !fds.is_empty() {
304            self.pending_fds.push_back(MessageFds {
305                offset: start_pos,
306                len: len + 1, // Include null terminator.
307                fds,
308            });
309        }
310
311        Ok(())
312    }
313
314    fn grow_buffer(&mut self) -> crate::Result<()> {
315        if self.buffer.len() >= super::MAX_BUFFER_SIZE {
316            return Err(crate::Error::BufferOverflow);
317        }
318
319        self.buffer.extend_from_slice(&[0; BUFFER_SIZE]);
320
321        Ok(())
322    }
323}
324
/// File descriptors that have to travel with one particular enqueued message.
#[cfg(feature = "std")]
#[derive(Debug)]
struct MessageFds {
    /// The descriptors passed along with the message's bytes.
    fds: Vec<OwnedFd>,
    /// Position in the write buffer at which the message begins.
    offset: usize,
    /// Size of the message in bytes, counting its trailing null terminator.
    len: usize,
}