Skip to main content

zlink_core/connection/
mod.rs

1//! Contains connection related API.
2//!
3//! The [`Connection`] type provides a low-level API for sending and receiving Varlink messages.
4//! For most use cases, you'll want to use the higher-level [`proxy`] and [`service`] attribute
5//! macros instead, which generate type-safe client and server code respectively.
6//!
7//! # Client Usage with `proxy` Macro
8//!
9//! The [`proxy`] macro generates methods on `Connection<S>` for calling remote service methods:
10//!
11//! ```
12//! #[zlink_core::proxy(
13//!     interface = "org.example.Calculator",
14//!     // Not needed in the real code because you'll use `proxy` through `zlink` crate.
15//!     crate = "zlink_core",
16//! )]
17//! trait CalculatorProxy {
18//!     async fn add(&mut self, a: f64, b: f64) -> zlink_core::Result<Result<f64, CalcError>>;
19//! }
20//!
21//! #[derive(Debug, zlink_core::ReplyError)]
22//! #[zlink(
23//!     interface = "org.example.Calculator",
24//!     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
25//!     crate = "zlink_core",
26//! )]
27//! enum CalcError {}
28//! ```
29//!
30//! # Server Usage with `service` Macro
31//!
32//! The [`service`] macro generates the [`Service`] trait implementation. See the [`service`] macro
33//! documentation for details and examples.
34//!
35//! # Low-Level API
36//!
37//! For advanced use cases that require more control, the [`Connection`] type provides direct access
38//! to message sending and receiving via methods like [`Connection::send_call`],
39//! [`Connection::receive_reply`], and [`Connection::chain_call`] for pipelining.
40//!
41//! [`proxy`]: macro@crate::proxy
42//! [`service`]: macro@crate::service
43//! [`Service`]: crate::service::Service
44
45#[cfg(feature = "std")]
46mod credentials;
47mod read_connection;
48#[cfg(feature = "std")]
49pub use credentials::Credentials;
50pub use read_connection::ReadConnection;
51#[cfg(feature = "std")]
52pub use rustix::{process::Pid, process::Uid};
53pub mod chain;
54pub mod socket;
55#[cfg(test)]
56mod tests;
57mod write_connection;
58use crate::{
59    reply::{self, Reply},
60    Call, Result,
61};
62#[cfg(feature = "std")]
63use alloc::vec;
64pub use chain::Chain;
65use core::{fmt::Debug, sync::atomic::AtomicUsize};
66#[cfg(feature = "std")]
67use socket::FetchPeerCredentials;
68pub use write_connection::WriteConnection;
69
70use serde::{Deserialize, Serialize};
71pub use socket::Socket;
72
// Return type shared by the receive methods (e.g. `receive_reply`, `receive_call`).
//
// With the `std` feature, the received value is paired with the file descriptors received
// with it (callers destructure it as `(value, fds)`). Without `std` there is no FD passing, so
// the value is returned on its own.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
78
/// A connection.
///
/// The low-level API to send and receive messages.
///
/// Each connection gets a unique identifier when created that can be queried using
/// [`Connection::id`]. This ID is shared between the read and write halves of the connection. It
/// can be used to associate the read and write halves of the same connection.
///
/// # Cancel safety
///
/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
/// documentation.
#[derive(Debug)]
pub struct Connection<S: Socket> {
    /// Read half, used by the receive methods.
    read: ReadConnection<S::ReadHalf>,
    /// Write half, used by the send/enqueue/flush methods.
    write: WriteConnection<S::WriteHalf>,
    /// Peer credentials, lazily fetched and cached by [`Connection::peer_credentials`].
    /// Always `None` right after [`Connection::new`] or [`Connection::join`].
    #[cfg(feature = "std")]
    credentials: Option<std::sync::Arc<Credentials>>,
}
98
99impl<S> Connection<S>
100where
101    S: Socket,
102{
103    /// Create a new connection.
104    pub fn new(socket: S) -> Self {
105        let (read, write) = socket.split();
106        let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
107        Self {
108            read: ReadConnection::new(read, id),
109            write: WriteConnection::new(write, id),
110            #[cfg(feature = "std")]
111            credentials: None,
112        }
113    }
114
115    /// The reference to the read half of the connection.
116    pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
117        &self.read
118    }
119
120    /// The mutable reference to the read half of the connection.
121    pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
122        &mut self.read
123    }
124
125    /// The reference to the write half of the connection.
126    pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
127        &self.write
128    }
129
130    /// The mutable reference to the write half of the connection.
131    pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
132        &mut self.write
133    }
134
135    /// Split the connection into read and write halves.
136    ///
137    /// Note: This consumes any cached credentials. If you need the credentials after splitting,
138    /// call [`Connection::peer_credentials`] before splitting.
139    pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
140        (self.read, self.write)
141    }
142
143    /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
144    pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
145        Self {
146            read,
147            write,
148            #[cfg(feature = "std")]
149            credentials: None,
150        }
151    }
152
153    /// The unique identifier of the connection.
154    pub fn id(&self) -> usize {
155        assert_eq!(self.read.id(), self.write.id());
156        self.read.id()
157    }
158
159    /// Sends a method call.
160    ///
161    /// Convenience wrapper around [`WriteConnection::send_call`].
162    pub async fn send_call<Method>(
163        &mut self,
164        call: &Call<Method>,
165        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
166    ) -> Result<()>
167    where
168        Method: Serialize + Debug,
169    {
170        #[cfg(feature = "std")]
171        {
172            self.write.send_call(call, fds).await
173        }
174        #[cfg(not(feature = "std"))]
175        {
176            self.write.send_call(call).await
177        }
178    }
179
180    /// Receives a method call reply.
181    ///
182    /// Convenience wrapper around [`ReadConnection::receive_reply`].
183    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
184        &'r mut self,
185    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
186    where
187        ReplyParams: Deserialize<'r> + Debug,
188        ReplyError: Deserialize<'r> + Debug,
189    {
190        self.read.receive_reply().await
191    }
192
193    /// Call a method and receive a reply.
194    ///
195    /// This is a convenience method that combines [`Connection::send_call`] and
196    /// [`Connection::receive_reply`].
197    pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
198        &'r mut self,
199        call: &Call<Method>,
200        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
201    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
202    where
203        Method: Serialize + Debug,
204        ReplyParams: Deserialize<'r> + Debug,
205        ReplyError: Deserialize<'r> + Debug,
206    {
207        #[cfg(feature = "std")]
208        self.send_call(call, fds).await?;
209        #[cfg(not(feature = "std"))]
210        self.send_call(call).await?;
211
212        self.receive_reply().await
213    }
214
215    /// Receive a method call over the socket.
216    ///
217    /// Convenience wrapper around [`ReadConnection::receive_call`].
218    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
219    where
220        Method: Deserialize<'m> + Debug,
221    {
222        self.read.receive_call().await
223    }
224
225    /// Send a reply over the socket.
226    ///
227    /// Convenience wrapper around [`WriteConnection::send_reply`].
228    pub async fn send_reply<ReplyParams>(
229        &mut self,
230        reply: &Reply<ReplyParams>,
231        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
232    ) -> Result<()>
233    where
234        ReplyParams: Serialize + Debug,
235    {
236        #[cfg(feature = "std")]
237        {
238            self.write.send_reply(reply, fds).await
239        }
240        #[cfg(not(feature = "std"))]
241        {
242            self.write.send_reply(reply).await
243        }
244    }
245
246    /// Send an error reply over the socket.
247    ///
248    /// Convenience wrapper around [`WriteConnection::send_error`].
249    pub async fn send_error<ReplyError>(
250        &mut self,
251        error: &ReplyError,
252        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
253    ) -> Result<()>
254    where
255        ReplyError: Serialize + Debug,
256    {
257        #[cfg(feature = "std")]
258        {
259            self.write.send_error(error, fds).await
260        }
261        #[cfg(not(feature = "std"))]
262        {
263            self.write.send_error(error).await
264        }
265    }
266
267    /// Enqueue a call to the server.
268    ///
269    /// Convenience wrapper around [`WriteConnection::enqueue_call`].
270    pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
271    where
272        Method: Serialize + Debug,
273    {
274        #[cfg(feature = "std")]
275        {
276            self.write.enqueue_call(method, vec![])
277        }
278        #[cfg(not(feature = "std"))]
279        {
280            self.write.enqueue_call(method)
281        }
282    }
283
284    /// Flush the connection.
285    ///
286    /// Convenience wrapper around [`WriteConnection::flush`].
287    pub async fn flush(&mut self) -> Result<()> {
288        self.write.flush().await
289    }
290
291    /// Start a chain of method calls.
292    ///
293    /// This allows batching multiple calls together and sending them in a single write operation.
294    ///
295    /// # Examples
296    ///
297    /// ## Basic Usage with Sequential Access
298    ///
299    /// ```no_run
300    /// use zlink_core::{Connection, Call, reply};
301    /// use serde::{Serialize, Deserialize};
302    /// use serde_prefix_all::prefix_all;
303    /// use futures_util::{pin_mut, stream::StreamExt};
304    ///
305    /// # async fn example() -> zlink_core::Result<()> {
306    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
307    ///
308    /// #[prefix_all("org.example.")]
309    /// #[derive(Debug, Serialize, Deserialize)]
310    /// #[serde(tag = "method", content = "parameters")]
311    /// enum Methods {
312    ///     GetUser { id: u32 },
313    ///     GetProject { id: u32 },
314    /// }
315    ///
316    /// #[derive(Debug, Deserialize)]
317    /// struct User { name: String }
318    ///
319    /// #[derive(Debug, Deserialize)]
320    /// struct Project { title: String }
321    ///
322    /// #[derive(Debug, zlink_core::ReplyError)]
323    /// #[zlink(
324    ///     interface = "org.example",
325    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
326    ///     crate = "zlink_core",
327    /// )]
328    /// enum ApiError {
329    ///     UserNotFound { code: i32 },
330    ///     ProjectNotFound { code: i32 },
331    /// }
332    ///
333    /// let get_user = Call::new(Methods::GetUser { id: 1 });
334    /// let get_project = Call::new(Methods::GetProject { id: 2 });
335    ///
336    /// // Chain calls and send them in a batch
337    /// # #[cfg(feature = "std")]
338    /// let replies = conn
339    ///     .chain_call::<Methods>(&get_user, vec![])?
340    ///     .append(&get_project, vec![])?
341    ///     .send::<User, ApiError>().await?;
342    /// # #[cfg(not(feature = "std"))]
343    /// # let replies = conn
344    /// #     .chain_call::<Methods>(&get_user)?
345    /// #     .append(&get_project)?
346    /// #     .send::<User, ApiError>().await?;
347    /// pin_mut!(replies);
348    ///
349    /// // Access replies sequentially.
350    /// # #[cfg(feature = "std")]
351    /// # {
352    /// let (user_reply, _fds) = replies.next().await.unwrap()?;
353    /// let (project_reply, _fds) = replies.next().await.unwrap()?;
354    ///
355    /// match user_reply {
356    ///     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
357    ///     Err(error) => println!("User error: {:?}", error),
358    /// }
359    /// # }
360    /// # #[cfg(not(feature = "std"))]
361    /// # {
362    /// # let user_reply = replies.next().await.unwrap()?;
363    /// # let _project_reply = replies.next().await.unwrap()?;
364    /// #
365    /// # match user_reply {
366    /// #     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
367    /// #     Err(error) => println!("User error: {:?}", error),
368    /// # }
369    /// # }
370    /// # Ok(())
371    /// # }
372    /// ```
373    ///
374    /// ## Arbitrary Number of Calls
375    ///
376    /// ```no_run
377    /// # use zlink_core::{Connection, Call, reply};
378    /// # use serde::{Serialize, Deserialize};
379    /// # use futures_util::{pin_mut, stream::StreamExt};
380    /// # use serde_prefix_all::prefix_all;
381    /// # async fn example() -> zlink_core::Result<()> {
382    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
383    /// # #[prefix_all("org.example.")]
384    /// # #[derive(Debug, Serialize, Deserialize)]
385    /// # #[serde(tag = "method", content = "parameters")]
386    /// # enum Methods {
387    /// #     GetUser { id: u32 },
388    /// # }
389    /// # #[derive(Debug, Deserialize)]
390    /// # struct User { name: String }
391    /// # #[derive(Debug, zlink_core::ReplyError)]
392    /// #[zlink(
393    ///     interface = "org.example",
394    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
395    ///     crate = "zlink_core",
396    /// )]
397    /// # enum ApiError {
398    /// #     UserNotFound { code: i32 },
399    /// #     ProjectNotFound { code: i32 },
400    /// # }
401    /// # let get_user = Call::new(Methods::GetUser { id: 1 });
402    ///
403    /// // Chain many calls (no upper limit)
404    /// # #[cfg(feature = "std")]
405    /// let mut chain = conn.chain_call::<Methods>(&get_user, vec![])?;
406    /// # #[cfg(not(feature = "std"))]
407    /// # let mut chain = conn.chain_call::<Methods>(&get_user)?;
408    /// # #[cfg(feature = "std")]
409    /// for i in 2..100 {
410    ///     chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
411    /// }
412    /// # #[cfg(not(feature = "std"))]
413    /// # for i in 2..100 {
414    /// #     chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
415    /// # }
416    ///
417    /// let replies = chain.send::<User, ApiError>().await?;
418    /// pin_mut!(replies);
419    ///
420    /// // Process all replies sequentially.
421    /// # #[cfg(feature = "std")]
422    /// while let Some(result) = replies.next().await {
423    ///     let (user_reply, _fds) = result?;
424    ///     // Handle each reply...
425    ///     match user_reply {
426    ///         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
427    ///         Err(error) => println!("Error: {:?}", error),
428    ///     }
429    /// }
430    /// # #[cfg(not(feature = "std"))]
431    /// # while let Some(result) = replies.next().await {
432    /// #     let user_reply = result?;
433    /// #     // Handle each reply...
434    /// #     match user_reply {
435    /// #         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
436    /// #         Err(error) => println!("Error: {:?}", error),
437    /// #     }
438    /// # }
439    /// # Ok(())
440    /// # }
441    /// ```
442    ///
443    /// # Performance Benefits
444    ///
445    /// Instead of multiple write operations, the chain sends all calls in a single
446    /// write operation, reducing context switching and therefore minimizing latency.
447    pub fn chain_call<'c, Method>(
448        &'c mut self,
449        call: &Call<Method>,
450        #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
451    ) -> Result<Chain<'c, S>>
452    where
453        Method: Serialize + Debug,
454    {
455        Chain::new(
456            self,
457            call,
458            #[cfg(feature = "std")]
459            fds,
460        )
461    }
462
463    /// Create a chain from an iterator of method calls.
464    ///
465    /// This allows creating a chain from any iterator yielding method types or calls. Each item
466    /// is automatically converted to a [`Call`] via [`Into<Call<Method>>`]. Unlike
467    /// [`Connection::chain_call`], this method allows building chains from dynamically-sized
468    /// collections.
469    ///
470    /// # Errors
471    ///
472    /// Returns [`Error::EmptyChain`] if the iterator is empty.
473    ///
474    /// # Examples
475    ///
476    /// ```no_run
477    /// use zlink_core::Connection;
478    /// use serde::{Serialize, Deserialize};
479    /// use serde_prefix_all::prefix_all;
480    /// use futures_util::{pin_mut, stream::StreamExt};
481    ///
482    /// # async fn example() -> zlink_core::Result<()> {
483    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
484    ///
485    /// #[prefix_all("org.example.")]
486    /// #[derive(Debug, Serialize, Deserialize)]
487    /// #[serde(tag = "method", content = "parameters")]
488    /// enum Methods {
489    ///     GetUser { id: u32 },
490    /// }
491    ///
492    /// #[derive(Debug, Deserialize)]
493    /// struct User { name: String }
494    ///
495    /// #[derive(Debug, zlink_core::ReplyError)]
496    /// #[zlink(interface = "org.example", crate = "zlink_core")]
497    /// enum ApiError {
498    ///     UserNotFound { code: i32 },
499    /// }
500    ///
501    /// let user_ids = [1, 2, 3, 4, 5];
502    /// let replies = conn
503    ///     .chain_from_iter::<Methods, _, _>(
504    ///         user_ids.iter().map(|&id| Methods::GetUser { id })
505    ///     )?
506    ///     .send::<User, ApiError>()
507    ///     .await?;
508    /// pin_mut!(replies);
509    ///
510    /// # #[cfg(feature = "std")]
511    /// while let Some(result) = replies.next().await {
512    ///     let (user_reply, _fds) = result?;
513    ///     // Handle each reply...
514    /// }
515    /// # Ok(())
516    /// # }
517    /// ```
518    ///
519    /// [`Error::EmptyChain`]: crate::Error::EmptyChain
520    pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
521        &'c mut self,
522        calls: MethodCalls,
523    ) -> Result<Chain<'c, S>>
524    where
525        Method: Serialize + Debug,
526        MethodCall: Into<Call<Method>>,
527        MethodCalls: IntoIterator<Item = MethodCall>,
528    {
529        let mut iter = calls.into_iter();
530        let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
531
532        #[cfg(feature = "std")]
533        let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
534        #[cfg(not(feature = "std"))]
535        let mut chain = Chain::new(self, &first)?;
536
537        for call in iter {
538            let call: Call<Method> = call.into();
539            #[cfg(feature = "std")]
540            {
541                chain = chain.append(&call, alloc::vec::Vec::new())?;
542            }
543            #[cfg(not(feature = "std"))]
544            {
545                chain = chain.append(&call)?;
546            }
547        }
548
549        Ok(chain)
550    }
551
552    /// Create a chain from an iterator of method calls with file descriptors.
553    ///
554    /// Similar to [`Connection::chain_from_iter`], but allows passing file descriptors with each
555    /// call. Each item in the iterator is a tuple of a method type (or [`Call`]) and its
556    /// associated file descriptors.
557    ///
558    /// # Errors
559    ///
560    /// Returns [`Error::EmptyChain`] if the iterator is empty.
561    ///
562    /// # Examples
563    ///
564    /// ```no_run
565    /// use zlink_core::Connection;
566    /// use serde::{Serialize, Deserialize};
567    /// use serde_prefix_all::prefix_all;
568    /// use std::os::fd::OwnedFd;
569    ///
570    /// # async fn example() -> zlink_core::Result<()> {
571    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
572    ///
573    /// #[prefix_all("org.example.")]
574    /// #[derive(Debug, Serialize, Deserialize)]
575    /// #[serde(tag = "method", content = "parameters")]
576    /// enum Methods {
577    ///     SendFile { name: String },
578    /// }
579    ///
580    /// #[derive(Debug, Deserialize)]
581    /// struct FileResult { success: bool }
582    ///
583    /// #[derive(Debug, zlink_core::ReplyError)]
584    /// #[zlink(interface = "org.example", crate = "zlink_core")]
585    /// enum ApiError {
586    ///     SendFailed { reason: String },
587    /// }
588    ///
589    /// let calls_with_fds: Vec<(Methods, Vec<OwnedFd>)> = vec![
590    ///     (Methods::SendFile { name: "file1.txt".into() }, vec![/* fd1 */]),
591    ///     (Methods::SendFile { name: "file2.txt".into() }, vec![/* fd2 */]),
592    /// ];
593    ///
594    /// let replies = conn
595    ///     .chain_from_iter_with_fds::<Methods, _, _>(calls_with_fds)?
596    ///     .send::<FileResult, ApiError>()
597    ///     .await?;
598    /// # Ok(())
599    /// # }
600    /// ```
601    ///
602    /// [`Error::EmptyChain`]: crate::Error::EmptyChain
603    #[cfg(feature = "std")]
604    pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
605        &'c mut self,
606        calls: MethodCalls,
607    ) -> Result<Chain<'c, S>>
608    where
609        Method: Serialize + Debug,
610        MethodCall: Into<Call<Method>>,
611        MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
612    {
613        let mut iter = calls.into_iter();
614        let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
615        let first: Call<Method> = first.into();
616        let mut chain = Chain::new(self, &first, first_fds)?;
617
618        for (call, fds) in iter {
619            let call: Call<Method> = call.into();
620            chain = chain.append(&call, fds)?;
621        }
622
623        Ok(chain)
624    }
625
626    /// Get the peer credentials.
627    ///
628    /// This method caches the credentials on the first call.
629    #[cfg(feature = "std")]
630    pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
631    where
632        S::ReadHalf: socket::FetchPeerCredentials,
633    {
634        if self.credentials.is_none() {
635            let creds = self.read.read_half().fetch_peer_credentials().await?;
636            self.credentials = Some(std::sync::Arc::new(creds));
637        }
638
639        // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
640        // method doesn't error out.
641        Ok(self.credentials.as_ref().unwrap())
642    }
643}
644
645impl<S> From<S> for Connection<S>
646where
647    S: Socket,
648{
649    fn from(socket: S) -> Self {
650        Self::new(socket)
651    }
652}
653
/// Source of connection IDs; each new connection takes the current value and bumps it by one.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

/// Size in bytes of the initial I/O buffer.
pub(crate) const BUFFER_SIZE: usize = 256;
/// Upper bound for buffer growth: 100 MiB (100 * 2^20 bytes).
const MAX_BUFFER_SIZE: usize = 100 << 20;