zlink_core/connection/
mod.rs

1//! Contains connection related API.
2
3#[cfg(feature = "std")]
4mod credentials;
5mod read_connection;
6#[cfg(feature = "std")]
7pub use credentials::Credentials;
8pub use read_connection::ReadConnection;
9#[cfg(feature = "std")]
10pub use rustix::{process::Pid, process::Uid};
11pub mod chain;
12pub mod socket;
13#[cfg(test)]
14mod tests;
15mod write_connection;
16use crate::{
17    reply::{self, Reply},
18    Call, Result,
19};
20#[cfg(feature = "std")]
21use alloc::vec;
22pub use chain::Chain;
23use core::{fmt::Debug, sync::atomic::AtomicUsize};
24#[cfg(feature = "std")]
25use socket::FetchPeerCredentials;
26pub use write_connection::WriteConnection;
27
28use serde::{Deserialize, Serialize};
29pub use socket::Socket;
30
/// Payload type returned by the `receive_*` methods when the `std` feature is disabled: just the
/// received value, since file descriptors are not supported there.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// Payload type returned by the `receive_*` methods when the `std` feature is enabled: the
/// received value paired with a list of owned file descriptors.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
36
37/// A connection.
38///
39/// The low-level API to send and receive messages.
40///
41/// Each connection gets a unique identifier when created that can be queried using
42/// [`Connection::id`]. This ID is shared betwen the read and write halves of the connection. It
43/// can be used to associate the read and write halves of the same connection.
44///
45/// # Cancel safety
46///
47/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
48/// documentation.
49#[derive(Debug)]
50pub struct Connection<S: Socket> {
51    read: ReadConnection<S::ReadHalf>,
52    write: WriteConnection<S::WriteHalf>,
53    #[cfg(feature = "std")]
54    credentials: Option<std::sync::Arc<Credentials>>,
55}
56
57impl<S> Connection<S>
58where
59    S: Socket,
60{
61    /// Create a new connection.
62    pub fn new(socket: S) -> Self {
63        let (read, write) = socket.split();
64        let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
65        Self {
66            read: ReadConnection::new(read, id),
67            write: WriteConnection::new(write, id),
68            #[cfg(feature = "std")]
69            credentials: None,
70        }
71    }
72
73    /// The reference to the read half of the connection.
74    pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
75        &self.read
76    }
77
78    /// The mutable reference to the read half of the connection.
79    pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
80        &mut self.read
81    }
82
83    /// The reference to the write half of the connection.
84    pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
85        &self.write
86    }
87
88    /// The mutable reference to the write half of the connection.
89    pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
90        &mut self.write
91    }
92
93    /// Split the connection into read and write halves.
94    ///
95    /// Note: This consumes any cached credentials. If you need the credentials after splitting,
96    /// call [`Connection::peer_credentials`] before splitting.
97    pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
98        (self.read, self.write)
99    }
100
101    /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
102    pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
103        Self {
104            read,
105            write,
106            #[cfg(feature = "std")]
107            credentials: None,
108        }
109    }
110
111    /// The unique identifier of the connection.
112    pub fn id(&self) -> usize {
113        assert_eq!(self.read.id(), self.write.id());
114        self.read.id()
115    }
116
117    /// Sends a method call.
118    ///
119    /// Convenience wrapper around [`WriteConnection::send_call`].
120    pub async fn send_call<Method>(
121        &mut self,
122        call: &Call<Method>,
123        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
124    ) -> Result<()>
125    where
126        Method: Serialize + Debug,
127    {
128        #[cfg(feature = "std")]
129        {
130            self.write.send_call(call, fds).await
131        }
132        #[cfg(not(feature = "std"))]
133        {
134            self.write.send_call(call).await
135        }
136    }
137
138    /// Receives a method call reply.
139    ///
140    /// Convenience wrapper around [`ReadConnection::receive_reply`].
141    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
142        &'r mut self,
143    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
144    where
145        ReplyParams: Deserialize<'r> + Debug,
146        ReplyError: Deserialize<'r> + Debug,
147    {
148        self.read.receive_reply().await
149    }
150
151    /// Call a method and receive a reply.
152    ///
153    /// This is a convenience method that combines [`Connection::send_call`] and
154    /// [`Connection::receive_reply`].
155    pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
156        &'r mut self,
157        call: &Call<Method>,
158        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
159    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
160    where
161        Method: Serialize + Debug,
162        ReplyParams: Deserialize<'r> + Debug,
163        ReplyError: Deserialize<'r> + Debug,
164    {
165        #[cfg(feature = "std")]
166        self.send_call(call, fds).await?;
167        #[cfg(not(feature = "std"))]
168        self.send_call(call).await?;
169
170        self.receive_reply().await
171    }
172
173    /// Receive a method call over the socket.
174    ///
175    /// Convenience wrapper around [`ReadConnection::receive_call`].
176    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
177    where
178        Method: Deserialize<'m> + Debug,
179    {
180        self.read.receive_call().await
181    }
182
183    /// Send a reply over the socket.
184    ///
185    /// Convenience wrapper around [`WriteConnection::send_reply`].
186    pub async fn send_reply<ReplyParams>(
187        &mut self,
188        reply: &Reply<ReplyParams>,
189        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
190    ) -> Result<()>
191    where
192        ReplyParams: Serialize + Debug,
193    {
194        #[cfg(feature = "std")]
195        {
196            self.write.send_reply(reply, fds).await
197        }
198        #[cfg(not(feature = "std"))]
199        {
200            self.write.send_reply(reply).await
201        }
202    }
203
204    /// Send an error reply over the socket.
205    ///
206    /// Convenience wrapper around [`WriteConnection::send_error`].
207    pub async fn send_error<ReplyError>(
208        &mut self,
209        error: &ReplyError,
210        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
211    ) -> Result<()>
212    where
213        ReplyError: Serialize + Debug,
214    {
215        #[cfg(feature = "std")]
216        {
217            self.write.send_error(error, fds).await
218        }
219        #[cfg(not(feature = "std"))]
220        {
221            self.write.send_error(error).await
222        }
223    }
224
225    /// Enqueue a call to the server.
226    ///
227    /// Convenience wrapper around [`WriteConnection::enqueue_call`].
228    pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
229    where
230        Method: Serialize + Debug,
231    {
232        #[cfg(feature = "std")]
233        {
234            self.write.enqueue_call(method, vec![])
235        }
236        #[cfg(not(feature = "std"))]
237        {
238            self.write.enqueue_call(method)
239        }
240    }
241
242    /// Flush the connection.
243    ///
244    /// Convenience wrapper around [`WriteConnection::flush`].
245    pub async fn flush(&mut self) -> Result<()> {
246        self.write.flush().await
247    }
248
249    /// Start a chain of method calls.
250    ///
251    /// This allows batching multiple calls together and sending them in a single write operation.
252    ///
253    /// # Examples
254    ///
255    /// ## Basic Usage with Sequential Access
256    ///
257    /// ```no_run
258    /// use zlink_core::{Connection, Call, reply};
259    /// use serde::{Serialize, Deserialize};
260    /// use serde_prefix_all::prefix_all;
261    /// use futures_util::{pin_mut, stream::StreamExt};
262    ///
263    /// # async fn example() -> zlink_core::Result<()> {
264    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
265    ///
266    /// #[prefix_all("org.example.")]
267    /// #[derive(Debug, Serialize, Deserialize)]
268    /// #[serde(tag = "method", content = "parameters")]
269    /// enum Methods {
270    ///     GetUser { id: u32 },
271    ///     GetProject { id: u32 },
272    /// }
273    ///
274    /// #[derive(Debug, Deserialize)]
275    /// struct User { name: String }
276    ///
277    /// #[derive(Debug, Deserialize)]
278    /// struct Project { title: String }
279    ///
280    /// #[derive(Debug, zlink_core::ReplyError)]
281    /// #[zlink(
282    ///     interface = "org.example",
283    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
284    ///     crate = "zlink_core",
285    /// )]
286    /// enum ApiError {
287    ///     UserNotFound { code: i32 },
288    ///     ProjectNotFound { code: i32 },
289    /// }
290    ///
291    /// let get_user = Call::new(Methods::GetUser { id: 1 });
292    /// let get_project = Call::new(Methods::GetProject { id: 2 });
293    ///
294    /// // Chain calls and send them in a batch
295    /// # #[cfg(feature = "std")]
296    /// let replies = conn
297    ///     .chain_call::<Methods, User, ApiError>(&get_user, vec![])?
298    ///     .append(&get_project, vec![])?
299    ///     .send().await?;
300    /// # #[cfg(not(feature = "std"))]
301    /// # let replies = conn
302    /// #     .chain_call::<Methods, User, ApiError>(&get_user)?
303    /// #     .append(&get_project)?
304    /// #     .send().await?;
305    /// pin_mut!(replies);
306    ///
307    /// // Access replies sequentially - types are now fixed by the chain
308    /// # #[cfg(feature = "std")]
309    /// # {
310    /// let (user_reply, _fds) = replies.next().await.unwrap()?;
311    /// let (project_reply, _fds) = replies.next().await.unwrap()?;
312    ///
313    /// match user_reply {
314    ///     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
315    ///     Err(error) => println!("User error: {:?}", error),
316    /// }
317    /// # }
318    /// # #[cfg(not(feature = "std"))]
319    /// # {
320    /// # let user_reply = replies.next().await.unwrap()?;
321    /// # let _project_reply = replies.next().await.unwrap()?;
322    /// #
323    /// # match user_reply {
324    /// #     Ok(user) => println!("User: {}", user.parameters().unwrap().name),
325    /// #     Err(error) => println!("User error: {:?}", error),
326    /// # }
327    /// # }
328    /// # Ok(())
329    /// # }
330    /// ```
331    ///
332    /// ## Arbitrary Number of Calls
333    ///
334    /// ```no_run
335    /// # use zlink_core::{Connection, Call, reply};
336    /// # use serde::{Serialize, Deserialize};
337    /// # use futures_util::{pin_mut, stream::StreamExt};
338    /// # use serde_prefix_all::prefix_all;
339    /// # async fn example() -> zlink_core::Result<()> {
340    /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
341    /// # #[prefix_all("org.example.")]
342    /// # #[derive(Debug, Serialize, Deserialize)]
343    /// # #[serde(tag = "method", content = "parameters")]
344    /// # enum Methods {
345    /// #     GetUser { id: u32 },
346    /// # }
347    /// # #[derive(Debug, Deserialize)]
348    /// # struct User { name: String }
349    /// # #[derive(Debug, zlink_core::ReplyError)]
350    /// #[zlink(
351    ///     interface = "org.example",
352    ///     // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
353    ///     crate = "zlink_core",
354    /// )]
355    /// # enum ApiError {
356    /// #     UserNotFound { code: i32 },
357    /// #     ProjectNotFound { code: i32 },
358    /// # }
359    /// # let get_user = Call::new(Methods::GetUser { id: 1 });
360    ///
361    /// // Chain many calls (no upper limit)
362    /// # #[cfg(feature = "std")]
363    /// let mut chain = conn.chain_call::<Methods, User, ApiError>(&get_user, vec![])?;
364    /// # #[cfg(not(feature = "std"))]
365    /// # let mut chain = conn.chain_call::<Methods, User, ApiError>(&get_user)?;
366    /// # #[cfg(feature = "std")]
367    /// for i in 2..100 {
368    ///     chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
369    /// }
370    /// # #[cfg(not(feature = "std"))]
371    /// # for i in 2..100 {
372    /// #     chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
373    /// # }
374    ///
375    /// let replies = chain.send().await?;
376    /// pin_mut!(replies);
377    ///
378    /// // Process all replies sequentially - types are fixed by the chain
379    /// # #[cfg(feature = "std")]
380    /// while let Some(result) = replies.next().await {
381    ///     let (user_reply, _fds) = result?;
382    ///     // Handle each reply...
383    ///     match user_reply {
384    ///         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
385    ///         Err(error) => println!("Error: {:?}", error),
386    ///     }
387    /// }
388    /// # #[cfg(not(feature = "std"))]
389    /// # while let Some(result) = replies.next().await {
390    /// #     let user_reply = result?;
391    /// #     // Handle each reply...
392    /// #     match user_reply {
393    /// #         Ok(user) => println!("User: {}", user.parameters().unwrap().name),
394    /// #         Err(error) => println!("Error: {:?}", error),
395    /// #     }
396    /// # }
397    /// # Ok(())
398    /// # }
399    /// ```
400    ///
401    /// # Performance Benefits
402    ///
403    /// Instead of multiple write operations, the chain sends all calls in a single
404    /// write operation, reducing context switching and therefore minimizing latency.
405    pub fn chain_call<'c, Method, ReplyParams, ReplyError>(
406        &'c mut self,
407        call: &Call<Method>,
408        #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
409    ) -> Result<Chain<'c, S, ReplyParams, ReplyError>>
410    where
411        Method: Serialize + Debug,
412        ReplyParams: Deserialize<'c> + Debug,
413        ReplyError: Deserialize<'c> + Debug,
414    {
415        Chain::new(
416            self,
417            call,
418            #[cfg(feature = "std")]
419            fds,
420        )
421    }
422
423    /// Get the peer credentials.
424    ///
425    /// This method caches the credentials on the first call.
426    #[cfg(feature = "std")]
427    pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
428    where
429        S::ReadHalf: socket::FetchPeerCredentials,
430    {
431        if self.credentials.is_none() {
432            let creds = self.read.read_half().fetch_peer_credentials().await?;
433            self.credentials = Some(std::sync::Arc::new(creds));
434        }
435
436        // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
437        // method doesn't error out.
438        Ok(self.credentials.as_ref().unwrap())
439    }
440}
441
442impl<S> From<S> for Connection<S>
443where
444    S: Socket,
445{
446    fn from(socket: S) -> Self {
447        Self::new(socket)
448    }
449}
450
/// Buffer size in bytes shared with the rest of the crate (presumably the starting size of the
/// read/write buffers — TODO confirm against its uses in the connection halves).
pub(crate) const BUFFER_SIZE: usize = 256;
/// Upper limit on buffer growth: 100 MiB (100 * 2^20 bytes).
const MAX_BUFFER_SIZE: usize = 100 << 20;

/// Counter handing out connection IDs; `Connection::new` takes the current value and bumps it.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);