zlink_core/connection/write_connection.rs
1//! Contains connection related API.
2
3use core::fmt::Debug;
4
5#[cfg(feature = "std")]
6use alloc::collections::VecDeque;
7use alloc::vec::Vec;
8use serde::Serialize;
9
10use super::{BUFFER_SIZE, Call, Reply, socket::WriteHalf};
11
12#[cfg(feature = "std")]
13use std::os::fd::OwnedFd;
14
15/// A connection.
16///
17/// The low-level API to send messages.
18///
19/// # Cancel safety
20///
21/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
22/// documentation.
23#[derive(Debug)]
24pub struct WriteConnection<Write: WriteHalf> {
25 pub(super) socket: Write,
26 pub(super) buffer: Vec<u8>,
27 pub(super) pos: usize,
28 id: usize,
29 #[cfg(feature = "std")]
30 pending_fds: VecDeque<MessageFds>,
31 // On macOS, SCM_RIGHTS does not properly hold a reference to the underlying file description
32 // when the sender closes the FD in the same process before the receiver calls `recvmsg`. The
33 // receiver ends up with a stale FD. We work around this by keeping sent FDs alive until the
34 // read half confirms it has called `recvmsg` for them via `drain_held_fds`.
35 #[cfg(all(feature = "std", target_os = "macos"))]
36 pub(super) held_fds: VecDeque<Vec<OwnedFd>>,
37}
38
39impl<Write: WriteHalf> WriteConnection<Write> {
40 /// Create a new connection.
41 pub(super) fn new(socket: Write, id: usize) -> Self {
42 Self {
43 socket,
44 id,
45 buffer: alloc::vec![0; BUFFER_SIZE],
46 pos: 0,
47 #[cfg(feature = "std")]
48 pending_fds: VecDeque::new(),
49 #[cfg(all(feature = "std", target_os = "macos"))]
50 held_fds: VecDeque::new(),
51 }
52 }
53
54 /// The unique identifier of the connection.
55 #[inline]
56 pub fn id(&self) -> usize {
57 self.id
58 }
59
60 /// Sends a method call.
61 ///
62 /// The generic `Method` is the type of the method name and its input parameters. This should be
63 /// a type that can serialize itself to a complete method call message, i-e an object containing
64 /// `method` and `parameter` fields. This can be easily achieved using the `serde::Serialize`
65 /// derive:
66 ///
67 /// ```rust
68 /// use serde::{Serialize, Deserialize};
69 /// use serde_prefix_all::prefix_all;
70 ///
71 /// #[prefix_all("org.example.ftl.")]
72 /// #[derive(Debug, Serialize, Deserialize)]
73 /// #[serde(tag = "method", content = "parameters")]
74 /// enum MyMethods<'m> {
75 /// // The name needs to be the fully-qualified name of the error.
76 /// Alpha { param1: u32, param2: &'m str},
77 /// Bravo,
78 /// Charlie { param1: &'m str },
79 /// }
80 /// ```
81 ///
82 /// The `fds` parameter contains file descriptors to send along with the call.
83 pub async fn send_call<Method>(
84 &mut self,
85 call: &Call<Method>,
86 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
87 ) -> crate::Result<()>
88 where
89 Method: Serialize + Debug,
90 {
91 trace!("connection {}: sending call: {:?}", self.id, call);
92 #[cfg(feature = "std")]
93 {
94 self.write(call, fds).await
95 }
96 #[cfg(not(feature = "std"))]
97 {
98 self.write(call).await
99 }
100 }
101
102 /// Send a reply over the socket.
103 ///
104 /// The generic parameter `Params` is the type of the successful reply. This should be a type
105 /// that can serialize itself as the `parameters` field of the reply.
106 ///
107 /// The `fds` parameter contains file descriptors to send along with the reply.
108 pub async fn send_reply<Params>(
109 &mut self,
110 reply: &Reply<Params>,
111 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
112 ) -> crate::Result<()>
113 where
114 Params: Serialize + Debug,
115 {
116 trace!("connection {}: sending reply: {:?}", self.id, reply);
117 #[cfg(feature = "std")]
118 {
119 self.write(reply, fds).await
120 }
121 #[cfg(not(feature = "std"))]
122 {
123 self.write(reply).await
124 }
125 }
126
127 /// Send an error reply over the socket.
128 ///
129 /// The generic parameter `ReplyError` is the type of the error reply. This should be a type
130 /// that can serialize itself to the whole reply object, containing `error` and `parameter`
131 /// fields. This can be easily achieved using the `serde::Serialize` derive (See the code
132 /// snippet in [`super::ReadConnection::receive_reply`] documentation for an example).
133 ///
134 /// The `fds` parameter contains file descriptors to send along with the error.
135 pub async fn send_error<ReplyError>(
136 &mut self,
137 error: &ReplyError,
138 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
139 ) -> crate::Result<()>
140 where
141 ReplyError: Serialize + Debug,
142 {
143 trace!("connection {}: sending error: {:?}", self.id, error);
144 #[cfg(feature = "std")]
145 {
146 self.write(error, fds).await
147 }
148 #[cfg(not(feature = "std"))]
149 {
150 self.write(error).await
151 }
152 }
153
154 /// Enqueue a call to be sent over the socket.
155 ///
156 /// Similar to [`WriteConnection::send_call`], except that the call is not sent immediately but
157 /// enqueued for later sending. This is useful when you want to send multiple calls in a
158 /// batch.
159 ///
160 /// The `fds` parameter contains file descriptors to send along with the call.
161 pub fn enqueue_call<Method>(
162 &mut self,
163 call: &Call<Method>,
164 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
165 ) -> crate::Result<()>
166 where
167 Method: Serialize + Debug,
168 {
169 trace!("connection {}: enqueuing call: {:?}", self.id, call);
170 #[cfg(feature = "std")]
171 {
172 self.enqueue(call, fds)
173 }
174 #[cfg(not(feature = "std"))]
175 {
176 self.enqueue(call)
177 }
178 }
179
180 /// Send out the enqueued calls.
181 pub async fn flush(&mut self) -> crate::Result<()> {
182 if self.pos == 0 {
183 return Ok(());
184 }
185
186 #[allow(unused_mut)]
187 let mut sent_pos = 0;
188
189 #[cfg(feature = "std")]
190 {
191 // While there are FDs, send one message at a time.
192 while !self.pending_fds.is_empty() {
193 // Get the first FD entry.
194 let pending = self.pending_fds.front().unwrap();
195 let fd_offset = pending.offset;
196 let msg_len = pending.len;
197
198 // If there are bytes before the FD message, send them first without FDs.
199 if sent_pos < fd_offset {
200 trace!(
201 "connection {}: flushing {} bytes before FD message",
202 self.id,
203 fd_offset - sent_pos
204 );
205 self.socket
206 .write(&self.buffer[sent_pos..fd_offset], &[] as &[OwnedFd])
207 .await?;
208 }
209
210 // Send this message with its FDs.
211 let msg_end = fd_offset + msg_len;
212 let MessageFds {
213 fds,
214 offset: _,
215 len: _,
216 } = self.pending_fds.pop_front().unwrap();
217 trace!(
218 "connection {}: flushing {} bytes with {} FDs",
219 self.id,
220 msg_len,
221 fds.len()
222 );
223 self.socket
224 .write(&self.buffer[fd_offset..msg_end], &fds)
225 .await?;
226 sent_pos = msg_end;
227
228 // On macOS, keep sent FDs alive until the read half confirms receipt via
229 // `recvmsg`. See the comment on the `held_fds` field for details.
230 #[cfg(target_os = "macos")]
231 self.held_fds.push_back(fds);
232 }
233 }
234
235 // No more FDs, send all remaining bytes at once.
236 if sent_pos < self.pos {
237 trace!(
238 "connection {}: flushing {} bytes",
239 self.id,
240 self.pos - sent_pos
241 );
242 #[cfg(feature = "std")]
243 {
244 self.socket
245 .write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
246 .await?;
247 }
248 #[cfg(not(feature = "std"))]
249 {
250 self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
251 }
252 }
253
254 self.pos = 0;
255 Ok(())
256 }
257
258 /// The underlying write half of the socket.
259 pub fn write_half(&self) -> &Write {
260 &self.socket
261 }
262
263 pub(super) async fn write<T>(
264 &mut self,
265 value: &T,
266 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
267 ) -> crate::Result<()>
268 where
269 T: Serialize + ?Sized + Debug,
270 {
271 #[cfg(feature = "std")]
272 {
273 self.enqueue(value, fds)?;
274 }
275 #[cfg(not(feature = "std"))]
276 {
277 self.enqueue(value)?;
278 }
279 self.flush().await
280 }
281
282 pub(super) fn enqueue<T>(
283 &mut self,
284 value: &T,
285 #[cfg(feature = "std")] fds: Vec<OwnedFd>,
286 ) -> crate::Result<()>
287 where
288 T: Serialize + ?Sized + Debug,
289 {
290 #[cfg(feature = "std")]
291 let start_pos = self.pos;
292
293 let len = loop {
294 match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
295 Ok(len) => break len,
296 Err(crate::json_ser::Error::BufferTooSmall) => {
297 // Buffer too small, grow it and retry
298 self.grow_buffer()?;
299 }
300 Err(crate::json_ser::Error::KeyMustBeAString) => {
301 // Actual serialization error
302 // Convert to serde_json::Error for public API
303 return Err(crate::Error::Json(serde::ser::Error::custom(
304 "key must be a string",
305 )));
306 }
307 }
308 };
309
310 // Add null terminator after this message.
311 if self.pos + len == self.buffer.len() {
312 self.grow_buffer()?;
313 }
314 self.buffer[self.pos + len] = b'\0';
315 self.pos += len + 1;
316
317 // Store FDs with message offset and length.
318 #[cfg(feature = "std")]
319 if !fds.is_empty() {
320 self.pending_fds.push_back(MessageFds {
321 offset: start_pos,
322 len: len + 1, // Include null terminator.
323 fds,
324 });
325 }
326
327 Ok(())
328 }
329
330 fn grow_buffer(&mut self) -> crate::Result<()> {
331 if self.buffer.len() >= super::MAX_BUFFER_SIZE {
332 return Err(crate::Error::BufferOverflow);
333 }
334
335 self.buffer.extend_from_slice(&[0; BUFFER_SIZE]);
336
337 Ok(())
338 }
339}
340
/// The file descriptors of one queued message, together with where that message lies in the
/// send buffer.
#[cfg(feature = "std")]
#[derive(Debug)]
struct MessageFds {
    /// The descriptors to pass alongside the message.
    fds: Vec<OwnedFd>,
    /// Byte offset of the message's first byte within the send buffer.
    offset: usize,
    /// Number of bytes the message occupies, counting its trailing NUL terminator.
    len: usize,
}