// zlink_core/connection/mod.rs

1//! Contains connection related API.
2
3#[cfg(feature = "std")]
4mod credentials;
5mod read_connection;
6#[cfg(feature = "std")]
7pub use credentials::Credentials;
8pub use read_connection::ReadConnection;
9#[cfg(feature = "std")]
10pub use rustix::{process::Pid, process::Uid};
11pub mod chain;
12pub mod socket;
13#[cfg(test)]
14mod tests;
15mod write_connection;
16use crate::{
17    reply::{self, Reply},
18    Call, Result,
19};
20#[cfg(feature = "std")]
21use alloc::vec;
22pub use chain::Chain;
23use core::{fmt::Debug, sync::atomic::AtomicUsize};
24#[cfg(feature = "std")]
25use socket::FetchPeerCredentials;
26pub use write_connection::WriteConnection;
27
28use serde::{Deserialize, Serialize};
29pub use socket::Socket;
30
/// Value produced by the `receive_*` methods.
///
/// Without the `std` feature this is just the received message. With `std`, the message is
/// paired with the file descriptors that were received alongside it.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
36
37/// A connection.
38///
39/// The low-level API to send and receive messages.
40///
41/// Each connection gets a unique identifier when created that can be queried using
42/// [`Connection::id`]. This ID is shared between the read and write halves of the connection. It
43/// can be used to associate the read and write halves of the same connection.
44///
45/// # Cancel safety
46///
47/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
48/// documentation.
49#[derive(Debug)]
50pub struct Connection<S: Socket> {
51    read: ReadConnection<S::ReadHalf>,
52    write: WriteConnection<S::WriteHalf>,
53    #[cfg(feature = "std")]
54    credentials: Option<std::sync::Arc<Credentials>>,
55}
56
57impl<S> Connection<S>
58where
59    S: Socket,
60{
61    /// Create a new connection.
62    pub fn new(socket: S) -> Self {
63        let (read, write) = socket.split();
64        let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
65        Self {
66            read: ReadConnection::new(read, id),
67            write: WriteConnection::new(write, id),
68            #[cfg(feature = "std")]
69            credentials: None,
70        }
71    }
72
73    /// The reference to the read half of the connection.
74    pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
75        &self.read
76    }
77
78    /// The mutable reference to the read half of the connection.
79    pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
80        &mut self.read
81    }
82
83    /// The reference to the write half of the connection.
84    pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
85        &self.write
86    }
87
88    /// The mutable reference to the write half of the connection.
89    pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
90        &mut self.write
91    }
92
93    /// Split the connection into read and write halves.
94    ///
95    /// Note: This consumes any cached credentials. If you need the credentials after splitting,
96    /// call [`Connection::peer_credentials`] before splitting.
97    pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
98        (self.read, self.write)
99    }
100
101    /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
102    pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
103        Self {
104            read,
105            write,
106            #[cfg(feature = "std")]
107            credentials: None,
108        }
109    }
110
111    /// The unique identifier of the connection.
112    pub fn id(&self) -> usize {
113        assert_eq!(self.read.id(), self.write.id());
114        self.read.id()
115    }
116
117    /// Sends a method call.
118    ///
119    /// Convenience wrapper around [`WriteConnection::send_call`].
120    pub async fn send_call<Method>(
121        &mut self,
122        call: &Call<Method>,
123        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
124    ) -> Result<()>
125    where
126        Method: Serialize + Debug,
127    {
128        #[cfg(feature = "std")]
129        {
130            self.write.send_call(call, fds).await
131        }
132        #[cfg(not(feature = "std"))]
133        {
134            self.write.send_call(call).await
135        }
136    }
137
138    /// Receives a method call reply.
139    ///
140    /// Convenience wrapper around [`ReadConnection::receive_reply`].
141    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
142        &'r mut self,
143    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
144    where
145        ReplyParams: Deserialize<'r> + Debug,
146        ReplyError: Deserialize<'r> + Debug,
147    {
148        self.read.receive_reply().await
149    }
150
151    /// Call a method and receive a reply.
152    ///
153    /// This is a convenience method that combines [`Connection::send_call`] and
154    /// [`Connection::receive_reply`].
155    pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
156        &'r mut self,
157        call: &Call<Method>,
158        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
159    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
160    where
161        Method: Serialize + Debug,
162        ReplyParams: Deserialize<'r> + Debug,
163        ReplyError: Deserialize<'r> + Debug,
164    {
165        #[cfg(feature = "std")]
166        self.send_call(call, fds).await?;
167        #[cfg(not(feature = "std"))]
168        self.send_call(call).await?;
169
170        self.receive_reply().await
171    }
172
173    /// Receive a method call over the socket.
174    ///
175    /// Convenience wrapper around [`ReadConnection::receive_call`].
176    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
177    where
178        Method: Deserialize<'m> + Debug,
179    {
180        self.read.receive_call().await
181    }
182
183    /// Send a reply over the socket.
184    ///
185    /// Convenience wrapper around [`WriteConnection::send_reply`].
186    pub async fn send_reply<ReplyParams>(
187        &mut self,
188        reply: &Reply<ReplyParams>,
189        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
190    ) -> Result<()>
191    where
192        ReplyParams: Serialize + Debug,
193    {
194        #[cfg(feature = "std")]
195        {
196            self.write.send_reply(reply, fds).await
197        }
198        #[cfg(not(feature = "std"))]
199        {
200            self.write.send_reply(reply).await
201        }
202    }
203
204    /// Send an error reply over the socket.
205    ///
206    /// Convenience wrapper around [`WriteConnection::send_error`].
207    pub async fn send_error<ReplyError>(
208        &mut self,
209        error: &ReplyError,
210        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
211    ) -> Result<()>
212    where
213        ReplyError: Serialize + Debug,
214    {
215        #[cfg(feature = "std")]
216        {
217            self.write.send_error(error, fds).await
218        }
219        #[cfg(not(feature = "std"))]
220        {
221            self.write.send_error(error).await
222        }
223    }
224
225    /// Enqueue a call to the server.
226    ///
227    /// Convenience wrapper around [`WriteConnection::enqueue_call`].
228    pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
229    where
230        Method: Serialize + Debug,
231    {
232        #[cfg(feature = "std")]
233        {
234            self.write.enqueue_call(method, vec![])
235        }
236        #[cfg(not(feature = "std"))]
237        {
238            self.write.enqueue_call(method)
239        }
240    }
241
242    /// Flush the connection.
243    ///
244    /// Convenience wrapper around [`WriteConnection::flush`].
245    pub async fn flush(&mut self) -> Result<()> {
246        self.write.flush().await
247    }
248
249    /// Start a chain of method calls.
250    ///
251    /// This allows batching multiple calls together and sending them in a single write operation.
252    ///
253    /// # Examples
254    ///
255    /// ## Basic Usage with Sequential Access
256    ///
257    /// ```no_run
258    /// use zlink_core::{Connection, Call, reply};
259    /// use serde::{Serialize, Deserialize};
260    /// use serde_prefix_all::prefix_all;
261    /// use futures_util::{pin_mut, stream::StreamExt};
262    ///
263    /// # async fn example() -> zlink_core::Result<()> {
264    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
265    ///
266    /// #[prefix_all("org.example.")]
267    /// #[derive(Debug, Serialize, Deserialize)]
268    /// #[serde(tag = "method", content = "parameters")]
269    /// enum Methods {
270    ///     GetUser { id: u32 },
271    ///     GetProject { id: u32 },
272    /// }
273    ///
274    /// #[derive(Debug, Deserialize)]
275    /// struct User { name: String }
276    ///
277    /// #[derive(Debug, Deserialize)]
278    /// struct Project { title: String }
279    ///
280    /// #[derive(Debug, zlink_core::ReplyError)]
281    /// #[zlink(
282    ///     interface = "org.example",
283    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
284    ///     crate = "zlink_core",
285    /// )]
286    /// enum ApiError {
287    ///     UserNotFound { code: i32 },
288    ///     ProjectNotFound { code: i32 },
289    /// }
290    ///
291    /// let get_user = Call::new(Methods::GetUser { id: 1 });
292    /// let get_project = Call::new(Methods::GetProject { id: 2 });
293    ///
294    /// // Chain calls and send them in a batch
295    /// # #[cfg(feature = "std")]
296    /// let replies = conn
297    ///     .chain_call::<Methods>(&get_user, vec![])?
298    ///     .append(&get_project, vec![])?
299    ///     .send::<User, ApiError>().await?;
300    /// # #[cfg(not(feature = "std"))]
301    /// # let replies = conn
302    /// #     .chain_call::<Methods>(&get_user)?
303    /// #     .append(&get_project)?
304    /// #     .send::<User, ApiError>().await?;
305    /// pin_mut!(replies);
306    ///
307    /// // Access replies sequentially.
308    /// # #[cfg(feature = "std")]
309    /// # {
310    /// let (user_reply, _fds) = replies.next().await.unwrap()?;
311    /// let (project_reply, _fds) = replies.next().await.unwrap()?;
312    ///
313    /// match user_reply {
314    ///     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
315    ///     Err(error) => println!("User error: {:?}", error),
316    /// }
317    /// # }
318    /// # #[cfg(not(feature = "std"))]
319    /// # {
320    /// # let user_reply = replies.next().await.unwrap()?;
321    /// # let _project_reply = replies.next().await.unwrap()?;
322    /// #
323    /// # match user_reply {
324    /// #     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
325    /// #     Err(error) => println!("User error: {:?}", error),
326    /// # }
327    /// # }
328    /// # Ok(())
329    /// # }
330    /// ```
331    ///
332    /// ## Arbitrary Number of Calls
333    ///
334    /// ```no_run
335    /// # use zlink_core::{Connection, Call, reply};
336    /// # use serde::{Serialize, Deserialize};
337    /// # use futures_util::{pin_mut, stream::StreamExt};
338    /// # use serde_prefix_all::prefix_all;
339    /// # async fn example() -> zlink_core::Result<()> {
340    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
341    /// # #[prefix_all("org.example.")]
342    /// # #[derive(Debug, Serialize, Deserialize)]
343    /// # #[serde(tag = "method", content = "parameters")]
344    /// # enum Methods {
345    /// #     GetUser { id: u32 },
346    /// # }
347    /// # #[derive(Debug, Deserialize)]
348    /// # struct User { name: String }
349    /// # #[derive(Debug, zlink_core::ReplyError)]
350    /// #[zlink(
351    ///     interface = "org.example",
352    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
353    ///     crate = "zlink_core",
354    /// )]
355    /// # enum ApiError {
356    /// #     UserNotFound { code: i32 },
357    /// #     ProjectNotFound { code: i32 },
358    /// # }
359    /// # let get_user = Call::new(Methods::GetUser { id: 1 });
360    ///
361    /// // Chain many calls (no upper limit)
362    /// # #[cfg(feature = "std")]
363    /// let mut chain = conn.chain_call::<Methods>(&get_user, vec![])?;
364    /// # #[cfg(not(feature = "std"))]
365    /// # let mut chain = conn.chain_call::<Methods>(&get_user)?;
366    /// # #[cfg(feature = "std")]
367    /// for i in 2..100 {
368    ///     chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
369    /// }
370    /// # #[cfg(not(feature = "std"))]
371    /// # for i in 2..100 {
372    /// #     chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
373    /// # }
374    ///
375    /// let replies = chain.send::<User, ApiError>().await?;
376    /// pin_mut!(replies);
377    ///
378    /// // Process all replies sequentially.
379    /// # #[cfg(feature = "std")]
380    /// while let Some(result) = replies.next().await {
381    ///     let (user_reply, _fds) = result?;
382    ///     // Handle each reply...
383    ///     match user_reply {
384    ///         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
385    ///         Err(error) => println!("Error: {:?}", error),
386    ///     }
387    /// }
388    /// # #[cfg(not(feature = "std"))]
389    /// # while let Some(result) = replies.next().await {
390    /// #     let user_reply = result?;
391    /// #     // Handle each reply...
392    /// #     match user_reply {
393    /// #         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
394    /// #         Err(error) => println!("Error: {:?}", error),
395    /// #     }
396    /// # }
397    /// # Ok(())
398    /// # }
399    /// ```
400    ///
401    /// # Performance Benefits
402    ///
403    /// Instead of multiple write operations, the chain sends all calls in a single
404    /// write operation, reducing context switching and therefore minimizing latency.
405    pub fn chain_call<'c, Method>(
406        &'c mut self,
407        call: &Call<Method>,
408        #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
409    ) -> Result<Chain<'c, S>>
410    where
411        Method: Serialize + Debug,
412    {
413        Chain::new(
414            self,
415            call,
416            #[cfg(feature = "std")]
417            fds,
418        )
419    }
420
421    /// Create a chain from an iterator of method calls.
422    ///
423    /// This allows creating a chain from any iterator yielding method types or calls. Each item
424    /// is automatically converted to a [`Call`] via [`Into<Call<Method>>`]. Unlike
425    /// [`Connection::chain_call`], this method allows building chains from dynamically-sized
426    /// collections.
427    ///
428    /// # Errors
429    ///
430    /// Returns [`Error::EmptyChain`] if the iterator is empty.
431    ///
432    /// # Examples
433    ///
434    /// ```no_run
435    /// use zlink_core::Connection;
436    /// use serde::{Serialize, Deserialize};
437    /// use serde_prefix_all::prefix_all;
438    /// use futures_util::{pin_mut, stream::StreamExt};
439    ///
440    /// # async fn example() -> zlink_core::Result<()> {
441    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
442    ///
443    /// #[prefix_all("org.example.")]
444    /// #[derive(Debug, Serialize, Deserialize)]
445    /// #[serde(tag = "method", content = "parameters")]
446    /// enum Methods {
447    ///     GetUser { id: u32 },
448    /// }
449    ///
450    /// #[derive(Debug, Deserialize)]
451    /// struct User { name: String }
452    ///
453    /// #[derive(Debug, zlink_core::ReplyError)]
454    /// #[zlink(interface = "org.example", crate = "zlink_core")]
455    /// enum ApiError {
456    ///     UserNotFound { code: i32 },
457    /// }
458    ///
459    /// let user_ids = [1, 2, 3, 4, 5];
460    /// let replies = conn
461    ///     .chain_from_iter::<Methods, _, _>(
462    ///         user_ids.iter().map(|&id| Methods::GetUser { id })
463    ///     )?
464    ///     .send::<User, ApiError>()
465    ///     .await?;
466    /// pin_mut!(replies);
467    ///
468    /// # #[cfg(feature = "std")]
469    /// while let Some(result) = replies.next().await {
470    ///     let (user_reply, _fds) = result?;
471    ///     // Handle each reply...
472    /// }
473    /// # Ok(())
474    /// # }
475    /// ```
476    ///
477    /// [`Error::EmptyChain`]: crate::Error::EmptyChain
478    pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
479        &'c mut self,
480        calls: MethodCalls,
481    ) -> Result<Chain<'c, S>>
482    where
483        Method: Serialize + Debug,
484        MethodCall: Into<Call<Method>>,
485        MethodCalls: IntoIterator<Item = MethodCall>,
486    {
487        let mut iter = calls.into_iter();
488        let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
489
490        #[cfg(feature = "std")]
491        let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
492        #[cfg(not(feature = "std"))]
493        let mut chain = Chain::new(self, &first)?;
494
495        for call in iter {
496            let call: Call<Method> = call.into();
497            #[cfg(feature = "std")]
498            {
499                chain = chain.append(&call, alloc::vec::Vec::new())?;
500            }
501            #[cfg(not(feature = "std"))]
502            {
503                chain = chain.append(&call)?;
504            }
505        }
506
507        Ok(chain)
508    }
509
510    /// Create a chain from an iterator of method calls with file descriptors.
511    ///
512    /// Similar to [`Connection::chain_from_iter`], but allows passing file descriptors with each
513    /// call. Each item in the iterator is a tuple of a method type (or [`Call`]) and its
514    /// associated file descriptors.
515    ///
516    /// # Errors
517    ///
518    /// Returns [`Error::EmptyChain`] if the iterator is empty.
519    ///
520    /// # Examples
521    ///
522    /// ```no_run
523    /// use zlink_core::Connection;
524    /// use serde::{Serialize, Deserialize};
525    /// use serde_prefix_all::prefix_all;
526    /// use std::os::fd::OwnedFd;
527    ///
528    /// # async fn example() -> zlink_core::Result<()> {
529    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
530    ///
531    /// #[prefix_all("org.example.")]
532    /// #[derive(Debug, Serialize, Deserialize)]
533    /// #[serde(tag = "method", content = "parameters")]
534    /// enum Methods {
535    ///     SendFile { name: String },
536    /// }
537    ///
538    /// #[derive(Debug, Deserialize)]
539    /// struct FileResult { success: bool }
540    ///
541    /// #[derive(Debug, zlink_core::ReplyError)]
542    /// #[zlink(interface = "org.example", crate = "zlink_core")]
543    /// enum ApiError {
544    ///     SendFailed { reason: String },
545    /// }
546    ///
547    /// let calls_with_fds: Vec<(Methods, Vec<OwnedFd>)> = vec![
548    ///     (Methods::SendFile { name: "file1.txt".into() }, vec![/* fd1 */]),
549    ///     (Methods::SendFile { name: "file2.txt".into() }, vec![/* fd2 */]),
550    /// ];
551    ///
552    /// let replies = conn
553    ///     .chain_from_iter_with_fds::<Methods, _, _>(calls_with_fds)?
554    ///     .send::<FileResult, ApiError>()
555    ///     .await?;
556    /// # Ok(())
557    /// # }
558    /// ```
559    ///
560    /// [`Error::EmptyChain`]: crate::Error::EmptyChain
561    #[cfg(feature = "std")]
562    pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
563        &'c mut self,
564        calls: MethodCalls,
565    ) -> Result<Chain<'c, S>>
566    where
567        Method: Serialize + Debug,
568        MethodCall: Into<Call<Method>>,
569        MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
570    {
571        let mut iter = calls.into_iter();
572        let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
573        let first: Call<Method> = first.into();
574        let mut chain = Chain::new(self, &first, first_fds)?;
575
576        for (call, fds) in iter {
577            let call: Call<Method> = call.into();
578            chain = chain.append(&call, fds)?;
579        }
580
581        Ok(chain)
582    }
583
584    /// Get the peer credentials.
585    ///
586    /// This method caches the credentials on the first call.
587    #[cfg(feature = "std")]
588    pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
589    where
590        S::ReadHalf: socket::FetchPeerCredentials,
591    {
592        if self.credentials.is_none() {
593            let creds = self.read.read_half().fetch_peer_credentials().await?;
594            self.credentials = Some(std::sync::Arc::new(creds));
595        }
596
597        // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
598        // method doesn't error out.
599        Ok(self.credentials.as_ref().unwrap())
600    }
601}
602
603impl<S> From<S> for Connection<S>
604where
605    S: Socket,
606{
607    fn from(socket: S) -> Self {
608        Self::new(socket)
609    }
610}
611
/// Counter from which every new connection draws its ID; it starts at zero and is bumped once
/// per connection.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

/// Size, in bytes, of the connection buffers (presumably their starting size; TODO confirm
/// against the read/write halves).
pub(crate) const BUFFER_SIZE: usize = 256;

/// Hard cap on buffer size: buffers are never allowed to exceed 100 MiB.
const MAX_BUFFER_SIZE: usize = 100 * 1024 * 1024;