zlink_core/connection/mod.rs
1//! Contains connection related API.
2//!
3//! The [`Connection`] type provides a low-level API for sending and receiving Varlink messages.
4//! For most use cases, you'll want to use the higher-level [`proxy`] and [`service`] attribute
5//! macros instead, which generate type-safe client and server code respectively.
6//!
7//! # Client Usage with `proxy` Macro
8//!
9//! The [`proxy`] macro generates methods on `Connection<S>` for calling remote service methods:
10//!
11//! ```
12//! #[zlink_core::proxy(
13//! interface = "org.example.Calculator",
14//! // Not needed in the real code because you'll use `proxy` through `zlink` crate.
15//! crate = "zlink_core",
16//! )]
17//! trait CalculatorProxy {
18//! async fn add(&mut self, a: f64, b: f64) -> zlink_core::Result<Result<f64, CalcError>>;
19//! }
20//!
21//! #[derive(Debug, zlink_core::ReplyError)]
22//! #[zlink(
23//! interface = "org.example.Calculator",
24//! // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
25//! crate = "zlink_core",
26//! )]
27//! enum CalcError {}
28//! ```
29//!
30//! # Server Usage with `service` Macro
31//!
32//! The [`service`] macro generates the [`Service`] trait implementation. See the [`service`] macro
33//! documentation for details and examples.
34//!
35//! # Low-Level API
36//!
37//! For advanced use cases that require more control, the [`Connection`] type provides direct access
38//! to message sending and receiving via methods like [`Connection::send_call`],
39//! [`Connection::receive_reply`], and [`Connection::chain_call`] for pipelining.
40//!
41//! [`proxy`]: macro@crate::proxy
42//! [`service`]: macro@crate::service
43//! [`Service`]: crate::service::Service
44
45#[cfg(feature = "std")]
46mod credentials;
47mod read_connection;
48#[cfg(feature = "std")]
49pub use credentials::Credentials;
50pub use read_connection::ReadConnection;
51#[cfg(feature = "std")]
52pub use rustix::{process::Gid, process::Pid, process::Uid};
53pub mod chain;
54pub mod socket;
55#[cfg(test)]
56mod tests;
57mod write_connection;
58use crate::{
59 Call, Result,
60 reply::{self, Reply},
61};
62#[cfg(feature = "std")]
63use alloc::vec;
64pub use chain::Chain;
65use core::{fmt::Debug, sync::atomic::AtomicUsize};
66#[cfg(feature = "std")]
67use socket::FetchPeerCredentials;
68pub use write_connection::WriteConnection;
69
70use serde::{Deserialize, Serialize};
71pub use socket::Socket;
72
73// Type alias for receive methods - std returns FDs, no_std doesn't
74#[cfg(feature = "std")]
75type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
76#[cfg(not(feature = "std"))]
77type RecvResult<T> = T;
78
79/// A connection.
80///
81/// The low-level API to send and receive messages.
82///
83/// Each connection gets a unique identifier when created that can be queried using
84/// [`Connection::id`]. This ID is shared between the read and write halves of the connection. It
85/// can be used to associate the read and write halves of the same connection.
86///
87/// # Cancel safety
88///
89/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
90/// documentation.
91#[derive(Debug)]
92pub struct Connection<S: Socket> {
93 read: ReadConnection<S::ReadHalf>,
94 write: WriteConnection<S::WriteHalf>,
95 #[cfg(feature = "std")]
96 credentials: Option<std::sync::Arc<Credentials>>,
97}
98
99impl<S> Connection<S>
100where
101 S: Socket,
102{
103 /// Create a new connection.
104 pub fn new(socket: S) -> Self {
105 let (read, write) = socket.split();
106 let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
107 Self {
108 read: ReadConnection::new(read, id),
109 write: WriteConnection::new(write, id),
110 #[cfg(feature = "std")]
111 credentials: None,
112 }
113 }
114
115 /// The reference to the read half of the connection.
116 pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
117 &self.read
118 }
119
120 /// The mutable reference to the read half of the connection.
121 pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
122 &mut self.read
123 }
124
125 /// The reference to the write half of the connection.
126 pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
127 &self.write
128 }
129
130 /// The mutable reference to the write half of the connection.
131 pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
132 &mut self.write
133 }
134
135 /// Split the connection into read and write halves.
136 ///
137 /// Note: This consumes any cached credentials. If you need the credentials after splitting,
138 /// call [`Connection::peer_credentials`] before splitting.
139 pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
140 (self.read, self.write)
141 }
142
143 /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
144 pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
145 Self {
146 read,
147 write,
148 #[cfg(feature = "std")]
149 credentials: None,
150 }
151 }
152
153 /// Release sent FDs that the read half has confirmed receiving via `recvmsg`.
154 /// See `WriteConnection::held_fds` for details on the macOS kernel issue.
155 #[cfg(all(feature = "std", target_os = "macos"))]
156 fn drain_held_fds(&mut self) {
157 let to_drain = self.read.fd_recvs;
158 for _ in 0..to_drain {
159 self.write.held_fds.pop_front();
160 }
161 self.read.fd_recvs -= to_drain;
162 }
163
164 /// The unique identifier of the connection.
165 pub fn id(&self) -> usize {
166 assert_eq!(self.read.id(), self.write.id());
167 self.read.id()
168 }
169
170 /// Sends a method call.
171 ///
172 /// Convenience wrapper around [`WriteConnection::send_call`].
173 pub async fn send_call<Method>(
174 &mut self,
175 call: &Call<Method>,
176 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
177 ) -> Result<()>
178 where
179 Method: Serialize + Debug,
180 {
181 #[cfg(feature = "std")]
182 {
183 self.write.send_call(call, fds).await
184 }
185 #[cfg(not(feature = "std"))]
186 {
187 self.write.send_call(call).await
188 }
189 }
190
191 /// Receives a method call reply.
192 ///
193 /// Convenience wrapper around [`ReadConnection::receive_reply`].
194 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
195 &'r mut self,
196 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
197 where
198 ReplyParams: Deserialize<'r> + Debug,
199 ReplyError: Deserialize<'r> + Debug,
200 {
201 self.read.receive_reply().await
202 }
203
204 /// Call a method and receive a reply.
205 ///
206 /// This is a convenience method that combines [`Connection::send_call`] and
207 /// [`Connection::receive_reply`].
208 pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
209 &'r mut self,
210 call: &Call<Method>,
211 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
212 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
213 where
214 Method: Serialize + Debug,
215 ReplyParams: Deserialize<'r> + Debug,
216 ReplyError: Deserialize<'r> + Debug,
217 {
218 #[cfg(feature = "std")]
219 self.send_call(call, fds).await?;
220 #[cfg(not(feature = "std"))]
221 self.send_call(call).await?;
222
223 self.receive_reply().await
224 }
225
226 /// Receive a method call over the socket.
227 ///
228 /// Convenience wrapper around [`ReadConnection::receive_call`].
229 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
230 where
231 Method: Deserialize<'m> + Debug,
232 {
233 self.read.receive_call().await
234 }
235
236 /// Send a reply over the socket.
237 ///
238 /// Convenience wrapper around [`WriteConnection::send_reply`].
239 pub async fn send_reply<ReplyParams>(
240 &mut self,
241 reply: &Reply<ReplyParams>,
242 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
243 ) -> Result<()>
244 where
245 ReplyParams: Serialize + Debug,
246 {
247 #[cfg(all(feature = "std", target_os = "macos"))]
248 self.drain_held_fds();
249 #[cfg(feature = "std")]
250 {
251 self.write.send_reply(reply, fds).await
252 }
253 #[cfg(not(feature = "std"))]
254 {
255 self.write.send_reply(reply).await
256 }
257 }
258
259 /// Send an error reply over the socket.
260 ///
261 /// Convenience wrapper around [`WriteConnection::send_error`].
262 pub async fn send_error<ReplyError>(
263 &mut self,
264 error: &ReplyError,
265 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
266 ) -> Result<()>
267 where
268 ReplyError: Serialize + Debug,
269 {
270 #[cfg(all(feature = "std", target_os = "macos"))]
271 self.drain_held_fds();
272 #[cfg(feature = "std")]
273 {
274 self.write.send_error(error, fds).await
275 }
276 #[cfg(not(feature = "std"))]
277 {
278 self.write.send_error(error).await
279 }
280 }
281
282 /// Enqueue a call to the server.
283 ///
284 /// Convenience wrapper around [`WriteConnection::enqueue_call`].
285 pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
286 where
287 Method: Serialize + Debug,
288 {
289 #[cfg(feature = "std")]
290 {
291 self.write.enqueue_call(method, vec![])
292 }
293 #[cfg(not(feature = "std"))]
294 {
295 self.write.enqueue_call(method)
296 }
297 }
298
299 /// Flush the connection.
300 ///
301 /// Convenience wrapper around [`WriteConnection::flush`].
302 pub async fn flush(&mut self) -> Result<()> {
303 self.write.flush().await
304 }
305
306 /// Start a chain of method calls.
307 ///
308 /// This allows batching multiple calls together and sending them in a single write operation.
309 ///
310 /// # Examples
311 ///
312 /// ## Basic Usage with Sequential Access
313 ///
314 /// ```no_run
315 /// use zlink_core::{Connection, Call, reply};
316 /// use serde::{Serialize, Deserialize};
317 /// use serde_prefix_all::prefix_all;
318 /// use futures_util::{pin_mut, stream::StreamExt};
319 ///
320 /// # async fn example() -> zlink_core::Result<()> {
321 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
322 ///
323 /// #[prefix_all("org.example.")]
324 /// #[derive(Debug, Serialize, Deserialize)]
325 /// #[serde(tag = "method", content = "parameters")]
326 /// enum Methods {
327 /// GetUser { id: u32 },
328 /// GetProject { id: u32 },
329 /// }
330 ///
331 /// #[derive(Debug, Deserialize)]
332 /// struct User { name: String }
333 ///
334 /// #[derive(Debug, Deserialize)]
335 /// struct Project { title: String }
336 ///
337 /// #[derive(Debug, zlink_core::ReplyError)]
338 /// #[zlink(
339 /// interface = "org.example",
340 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
341 /// crate = "zlink_core",
342 /// )]
343 /// enum ApiError {
344 /// UserNotFound { code: i32 },
345 /// ProjectNotFound { code: i32 },
346 /// }
347 ///
348 /// let get_user = Call::new(Methods::GetUser { id: 1 });
349 /// let get_project = Call::new(Methods::GetProject { id: 2 });
350 ///
351 /// // Chain calls and send them in a batch
352 /// # #[cfg(feature = "std")]
353 /// let replies = conn
354 /// .chain_call::<Methods>(&get_user, vec![])?
355 /// .append(&get_project, vec![])?
356 /// .send::<User, ApiError>().await?;
357 /// # #[cfg(not(feature = "std"))]
358 /// # let replies = conn
359 /// # .chain_call::<Methods>(&get_user)?
360 /// # .append(&get_project)?
361 /// # .send::<User, ApiError>().await?;
362 /// pin_mut!(replies);
363 ///
364 /// // Access replies sequentially.
365 /// # #[cfg(feature = "std")]
366 /// # {
367 /// let (user_reply, _fds) = replies.next().await.unwrap()?;
368 /// let (project_reply, _fds) = replies.next().await.unwrap()?;
369 ///
370 /// match user_reply {
371 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
372 /// Err(error) => println!("User error: {:?}", error),
373 /// }
374 /// # }
375 /// # #[cfg(not(feature = "std"))]
376 /// # {
377 /// # let user_reply = replies.next().await.unwrap()?;
378 /// # let _project_reply = replies.next().await.unwrap()?;
379 /// #
380 /// # match user_reply {
381 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
382 /// # Err(error) => println!("User error: {:?}", error),
383 /// # }
384 /// # }
385 /// # Ok(())
386 /// # }
387 /// ```
388 ///
389 /// ## Arbitrary Number of Calls
390 ///
391 /// ```no_run
392 /// # use zlink_core::{Connection, Call, reply};
393 /// # use serde::{Serialize, Deserialize};
394 /// # use futures_util::{pin_mut, stream::StreamExt};
395 /// # use serde_prefix_all::prefix_all;
396 /// # async fn example() -> zlink_core::Result<()> {
397 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
398 /// # #[prefix_all("org.example.")]
399 /// # #[derive(Debug, Serialize, Deserialize)]
400 /// # #[serde(tag = "method", content = "parameters")]
401 /// # enum Methods {
402 /// # GetUser { id: u32 },
403 /// # }
404 /// # #[derive(Debug, Deserialize)]
405 /// # struct User { name: String }
406 /// # #[derive(Debug, zlink_core::ReplyError)]
407 /// #[zlink(
408 /// interface = "org.example",
409 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
410 /// crate = "zlink_core",
411 /// )]
412 /// # enum ApiError {
413 /// # UserNotFound { code: i32 },
414 /// # ProjectNotFound { code: i32 },
415 /// # }
416 /// # let get_user = Call::new(Methods::GetUser { id: 1 });
417 ///
418 /// // Chain many calls (no upper limit)
419 /// # #[cfg(feature = "std")]
420 /// let mut chain = conn.chain_call::<Methods>(&get_user, vec![])?;
421 /// # #[cfg(not(feature = "std"))]
422 /// # let mut chain = conn.chain_call::<Methods>(&get_user)?;
423 /// # #[cfg(feature = "std")]
424 /// for i in 2..100 {
425 /// chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
426 /// }
427 /// # #[cfg(not(feature = "std"))]
428 /// # for i in 2..100 {
429 /// # chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
430 /// # }
431 ///
432 /// let replies = chain.send::<User, ApiError>().await?;
433 /// pin_mut!(replies);
434 ///
435 /// // Process all replies sequentially.
436 /// # #[cfg(feature = "std")]
437 /// while let Some(result) = replies.next().await {
438 /// let (user_reply, _fds) = result?;
439 /// // Handle each reply...
440 /// match user_reply {
441 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
442 /// Err(error) => println!("Error: {:?}", error),
443 /// }
444 /// }
445 /// # #[cfg(not(feature = "std"))]
446 /// # while let Some(result) = replies.next().await {
447 /// # let user_reply = result?;
448 /// # // Handle each reply...
449 /// # match user_reply {
450 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
451 /// # Err(error) => println!("Error: {:?}", error),
452 /// # }
453 /// # }
454 /// # Ok(())
455 /// # }
456 /// ```
457 ///
458 /// # Performance Benefits
459 ///
460 /// Instead of multiple write operations, the chain sends all calls in a single
461 /// write operation, reducing context switching and therefore minimizing latency.
462 pub fn chain_call<'c, Method>(
463 &'c mut self,
464 call: &Call<Method>,
465 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
466 ) -> Result<Chain<'c, S>>
467 where
468 Method: Serialize + Debug,
469 {
470 Chain::new(
471 self,
472 call,
473 #[cfg(feature = "std")]
474 fds,
475 )
476 }
477
478 /// Create a chain from an iterator of method calls.
479 ///
480 /// This allows creating a chain from any iterator yielding method types or calls. Each item
481 /// is automatically converted to a [`Call`] via [`Into<Call<Method>>`]. Unlike
482 /// [`Connection::chain_call`], this method allows building chains from dynamically-sized
483 /// collections.
484 ///
485 /// # Errors
486 ///
487 /// Returns [`Error::EmptyChain`] if the iterator is empty.
488 ///
489 /// # Examples
490 ///
491 /// ```no_run
492 /// use zlink_core::Connection;
493 /// use serde::{Serialize, Deserialize};
494 /// use serde_prefix_all::prefix_all;
495 /// use futures_util::{pin_mut, stream::StreamExt};
496 ///
497 /// # async fn example() -> zlink_core::Result<()> {
498 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
499 ///
500 /// #[prefix_all("org.example.")]
501 /// #[derive(Debug, Serialize, Deserialize)]
502 /// #[serde(tag = "method", content = "parameters")]
503 /// enum Methods {
504 /// GetUser { id: u32 },
505 /// }
506 ///
507 /// #[derive(Debug, Deserialize)]
508 /// struct User { name: String }
509 ///
510 /// #[derive(Debug, zlink_core::ReplyError)]
511 /// #[zlink(interface = "org.example", crate = "zlink_core")]
512 /// enum ApiError {
513 /// UserNotFound { code: i32 },
514 /// }
515 ///
516 /// let user_ids = [1, 2, 3, 4, 5];
517 /// let replies = conn
518 /// .chain_from_iter::<Methods, _, _>(
519 /// user_ids.iter().map(|&id| Methods::GetUser { id })
520 /// )?
521 /// .send::<User, ApiError>()
522 /// .await?;
523 /// pin_mut!(replies);
524 ///
525 /// # #[cfg(feature = "std")]
526 /// while let Some(result) = replies.next().await {
527 /// let (user_reply, _fds) = result?;
528 /// // Handle each reply...
529 /// }
530 /// # Ok(())
531 /// # }
532 /// ```
533 ///
534 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
535 pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
536 &'c mut self,
537 calls: MethodCalls,
538 ) -> Result<Chain<'c, S>>
539 where
540 Method: Serialize + Debug,
541 MethodCall: Into<Call<Method>>,
542 MethodCalls: IntoIterator<Item = MethodCall>,
543 {
544 let mut iter = calls.into_iter();
545 let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
546
547 #[cfg(feature = "std")]
548 let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
549 #[cfg(not(feature = "std"))]
550 let mut chain = Chain::new(self, &first)?;
551
552 for call in iter {
553 let call: Call<Method> = call.into();
554 #[cfg(feature = "std")]
555 {
556 chain = chain.append(&call, alloc::vec::Vec::new())?;
557 }
558 #[cfg(not(feature = "std"))]
559 {
560 chain = chain.append(&call)?;
561 }
562 }
563
564 Ok(chain)
565 }
566
567 /// Create a chain from an iterator of method calls with file descriptors.
568 ///
569 /// Similar to [`Connection::chain_from_iter`], but allows passing file descriptors with each
570 /// call. Each item in the iterator is a tuple of a method type (or [`Call`]) and its
571 /// associated file descriptors.
572 ///
573 /// # Errors
574 ///
575 /// Returns [`Error::EmptyChain`] if the iterator is empty.
576 ///
577 /// # Examples
578 ///
579 /// ```no_run
580 /// use zlink_core::Connection;
581 /// use serde::{Serialize, Deserialize};
582 /// use serde_prefix_all::prefix_all;
583 /// use std::os::fd::OwnedFd;
584 ///
585 /// # async fn example() -> zlink_core::Result<()> {
586 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
587 ///
588 /// #[prefix_all("org.example.")]
589 /// #[derive(Debug, Serialize, Deserialize)]
590 /// #[serde(tag = "method", content = "parameters")]
591 /// enum Methods {
592 /// SendFile { name: String },
593 /// }
594 ///
595 /// #[derive(Debug, Deserialize)]
596 /// struct FileResult { success: bool }
597 ///
598 /// #[derive(Debug, zlink_core::ReplyError)]
599 /// #[zlink(interface = "org.example", crate = "zlink_core")]
600 /// enum ApiError {
601 /// SendFailed { reason: String },
602 /// }
603 ///
604 /// let calls_with_fds: Vec<(Methods, Vec<OwnedFd>)> = vec![
605 /// (Methods::SendFile { name: "file1.txt".into() }, vec![/* fd1 */]),
606 /// (Methods::SendFile { name: "file2.txt".into() }, vec![/* fd2 */]),
607 /// ];
608 ///
609 /// let replies = conn
610 /// .chain_from_iter_with_fds::<Methods, _, _>(calls_with_fds)?
611 /// .send::<FileResult, ApiError>()
612 /// .await?;
613 /// # Ok(())
614 /// # }
615 /// ```
616 ///
617 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
618 #[cfg(feature = "std")]
619 pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
620 &'c mut self,
621 calls: MethodCalls,
622 ) -> Result<Chain<'c, S>>
623 where
624 Method: Serialize + Debug,
625 MethodCall: Into<Call<Method>>,
626 MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
627 {
628 let mut iter = calls.into_iter();
629 let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
630 let first: Call<Method> = first.into();
631 let mut chain = Chain::new(self, &first, first_fds)?;
632
633 for (call, fds) in iter {
634 let call: Call<Method> = call.into();
635 chain = chain.append(&call, fds)?;
636 }
637
638 Ok(chain)
639 }
640
641 /// Get the peer credentials.
642 ///
643 /// This method caches the credentials on the first call.
644 #[cfg(feature = "std")]
645 pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
646 where
647 S::ReadHalf: socket::FetchPeerCredentials,
648 {
649 if self.credentials.is_none() {
650 let creds = self.read.read_half().fetch_peer_credentials().await?;
651 self.credentials = Some(std::sync::Arc::new(creds));
652 }
653
654 // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
655 // method doesn't error out.
656 Ok(self.credentials.as_ref().unwrap())
657 }
658}
659
660impl<S> From<S> for Connection<S>
661where
662 S: Socket,
663{
664 fn from(socket: S) -> Self {
665 Self::new(socket)
666 }
667}
668
669pub(crate) const BUFFER_SIZE: usize = 256;
670const MAX_BUFFER_SIZE: usize = 100 * 1024 * 1024; // Don't allow buffers over 100MB.
671
672static NEXT_ID: AtomicUsize = AtomicUsize::new(0);