zlink_core/connection/mod.rs
1//! Contains connection related API.
2//!
3//! The [`Connection`] type provides a low-level API for sending and receiving Varlink messages.
4//! For most use cases, you'll want to use the higher-level [`proxy`] and [`service`] attribute
5//! macros instead, which generate type-safe client and server code respectively.
6//!
7//! # Client Usage with `proxy` Macro
8//!
9//! The [`proxy`] macro generates methods on `Connection<S>` for calling remote service methods:
10//!
11//! ```
12//! #[zlink_core::proxy(
13//! interface = "org.example.Calculator",
14//! // Not needed in the real code because you'll use `proxy` through `zlink` crate.
15//! crate = "zlink_core",
16//! )]
17//! trait CalculatorProxy {
18//! async fn add(&mut self, a: f64, b: f64) -> zlink_core::Result<Result<f64, CalcError>>;
19//! }
20//!
21//! #[derive(Debug, zlink_core::ReplyError)]
22//! #[zlink(
23//! interface = "org.example.Calculator",
24//! // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
25//! crate = "zlink_core",
26//! )]
27//! enum CalcError {}
28//! ```
29//!
30//! # Server Usage with `service` Macro
31//!
32//! The [`service`] macro generates the [`Service`] trait implementation. See the [`service`] macro
33//! documentation for details and examples.
34//!
35//! # Low-Level API
36//!
37//! For advanced use cases that require more control, the [`Connection`] type provides direct access
38//! to message sending and receiving via methods like [`Connection::send_call`],
39//! [`Connection::receive_reply`], and [`Connection::chain_call`] for pipelining.
40//!
41//! [`proxy`]: macro@crate::proxy
42//! [`service`]: macro@crate::service
43//! [`Service`]: crate::service::Service
44
45#[cfg(feature = "std")]
46mod credentials;
47mod read_connection;
48#[cfg(feature = "std")]
49pub use credentials::Credentials;
50pub use read_connection::ReadConnection;
51#[cfg(feature = "std")]
52pub use rustix::{process::Pid, process::Uid};
53pub mod chain;
54pub mod socket;
55#[cfg(test)]
56mod tests;
57mod write_connection;
58use crate::{
59 reply::{self, Reply},
60 Call, Result,
61};
62#[cfg(feature = "std")]
63use alloc::vec;
64pub use chain::Chain;
65use core::{fmt::Debug, sync::atomic::AtomicUsize};
66#[cfg(feature = "std")]
67use socket::FetchPeerCredentials;
68pub use write_connection::WriteConnection;
69
70use serde::{Deserialize, Serialize};
71pub use socket::Socket;
72
73// Type alias for receive methods - std returns FDs, no_std doesn't
74#[cfg(feature = "std")]
75type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
76#[cfg(not(feature = "std"))]
77type RecvResult<T> = T;
78
79/// A connection.
80///
81/// The low-level API to send and receive messages.
82///
83/// Each connection gets a unique identifier when created that can be queried using
84/// [`Connection::id`]. This ID is shared between the read and write halves of the connection. It
85/// can be used to associate the read and write halves of the same connection.
86///
87/// # Cancel safety
88///
89/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
90/// documentation.
91#[derive(Debug)]
92pub struct Connection<S: Socket> {
93 read: ReadConnection<S::ReadHalf>,
94 write: WriteConnection<S::WriteHalf>,
95 #[cfg(feature = "std")]
96 credentials: Option<std::sync::Arc<Credentials>>,
97}
98
99impl<S> Connection<S>
100where
101 S: Socket,
102{
103 /// Create a new connection.
104 pub fn new(socket: S) -> Self {
105 let (read, write) = socket.split();
106 let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
107 Self {
108 read: ReadConnection::new(read, id),
109 write: WriteConnection::new(write, id),
110 #[cfg(feature = "std")]
111 credentials: None,
112 }
113 }
114
115 /// The reference to the read half of the connection.
116 pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
117 &self.read
118 }
119
120 /// The mutable reference to the read half of the connection.
121 pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
122 &mut self.read
123 }
124
125 /// The reference to the write half of the connection.
126 pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
127 &self.write
128 }
129
130 /// The mutable reference to the write half of the connection.
131 pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
132 &mut self.write
133 }
134
135 /// Split the connection into read and write halves.
136 ///
137 /// Note: This consumes any cached credentials. If you need the credentials after splitting,
138 /// call [`Connection::peer_credentials`] before splitting.
139 pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
140 (self.read, self.write)
141 }
142
143 /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
144 pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
145 Self {
146 read,
147 write,
148 #[cfg(feature = "std")]
149 credentials: None,
150 }
151 }
152
153 /// The unique identifier of the connection.
154 pub fn id(&self) -> usize {
155 assert_eq!(self.read.id(), self.write.id());
156 self.read.id()
157 }
158
159 /// Sends a method call.
160 ///
161 /// Convenience wrapper around [`WriteConnection::send_call`].
162 pub async fn send_call<Method>(
163 &mut self,
164 call: &Call<Method>,
165 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
166 ) -> Result<()>
167 where
168 Method: Serialize + Debug,
169 {
170 #[cfg(feature = "std")]
171 {
172 self.write.send_call(call, fds).await
173 }
174 #[cfg(not(feature = "std"))]
175 {
176 self.write.send_call(call).await
177 }
178 }
179
180 /// Receives a method call reply.
181 ///
182 /// Convenience wrapper around [`ReadConnection::receive_reply`].
183 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
184 &'r mut self,
185 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
186 where
187 ReplyParams: Deserialize<'r> + Debug,
188 ReplyError: Deserialize<'r> + Debug,
189 {
190 self.read.receive_reply().await
191 }
192
193 /// Call a method and receive a reply.
194 ///
195 /// This is a convenience method that combines [`Connection::send_call`] and
196 /// [`Connection::receive_reply`].
197 pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
198 &'r mut self,
199 call: &Call<Method>,
200 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
201 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
202 where
203 Method: Serialize + Debug,
204 ReplyParams: Deserialize<'r> + Debug,
205 ReplyError: Deserialize<'r> + Debug,
206 {
207 #[cfg(feature = "std")]
208 self.send_call(call, fds).await?;
209 #[cfg(not(feature = "std"))]
210 self.send_call(call).await?;
211
212 self.receive_reply().await
213 }
214
215 /// Receive a method call over the socket.
216 ///
217 /// Convenience wrapper around [`ReadConnection::receive_call`].
218 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
219 where
220 Method: Deserialize<'m> + Debug,
221 {
222 self.read.receive_call().await
223 }
224
225 /// Send a reply over the socket.
226 ///
227 /// Convenience wrapper around [`WriteConnection::send_reply`].
228 pub async fn send_reply<ReplyParams>(
229 &mut self,
230 reply: &Reply<ReplyParams>,
231 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
232 ) -> Result<()>
233 where
234 ReplyParams: Serialize + Debug,
235 {
236 #[cfg(feature = "std")]
237 {
238 self.write.send_reply(reply, fds).await
239 }
240 #[cfg(not(feature = "std"))]
241 {
242 self.write.send_reply(reply).await
243 }
244 }
245
246 /// Send an error reply over the socket.
247 ///
248 /// Convenience wrapper around [`WriteConnection::send_error`].
249 pub async fn send_error<ReplyError>(
250 &mut self,
251 error: &ReplyError,
252 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
253 ) -> Result<()>
254 where
255 ReplyError: Serialize + Debug,
256 {
257 #[cfg(feature = "std")]
258 {
259 self.write.send_error(error, fds).await
260 }
261 #[cfg(not(feature = "std"))]
262 {
263 self.write.send_error(error).await
264 }
265 }
266
267 /// Enqueue a call to the server.
268 ///
269 /// Convenience wrapper around [`WriteConnection::enqueue_call`].
270 pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
271 where
272 Method: Serialize + Debug,
273 {
274 #[cfg(feature = "std")]
275 {
276 self.write.enqueue_call(method, vec![])
277 }
278 #[cfg(not(feature = "std"))]
279 {
280 self.write.enqueue_call(method)
281 }
282 }
283
284 /// Flush the connection.
285 ///
286 /// Convenience wrapper around [`WriteConnection::flush`].
287 pub async fn flush(&mut self) -> Result<()> {
288 self.write.flush().await
289 }
290
291 /// Start a chain of method calls.
292 ///
293 /// This allows batching multiple calls together and sending them in a single write operation.
294 ///
295 /// # Examples
296 ///
297 /// ## Basic Usage with Sequential Access
298 ///
299 /// ```no_run
300 /// use zlink_core::{Connection, Call, reply};
301 /// use serde::{Serialize, Deserialize};
302 /// use serde_prefix_all::prefix_all;
303 /// use futures_util::{pin_mut, stream::StreamExt};
304 ///
305 /// # async fn example() -> zlink_core::Result<()> {
306 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
307 ///
308 /// #[prefix_all("org.example.")]
309 /// #[derive(Debug, Serialize, Deserialize)]
310 /// #[serde(tag = "method", content = "parameters")]
311 /// enum Methods {
312 /// GetUser { id: u32 },
313 /// GetProject { id: u32 },
314 /// }
315 ///
316 /// #[derive(Debug, Deserialize)]
317 /// struct User { name: String }
318 ///
319 /// #[derive(Debug, Deserialize)]
320 /// struct Project { title: String }
321 ///
322 /// #[derive(Debug, zlink_core::ReplyError)]
323 /// #[zlink(
324 /// interface = "org.example",
325 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
326 /// crate = "zlink_core",
327 /// )]
328 /// enum ApiError {
329 /// UserNotFound { code: i32 },
330 /// ProjectNotFound { code: i32 },
331 /// }
332 ///
333 /// let get_user = Call::new(Methods::GetUser { id: 1 });
334 /// let get_project = Call::new(Methods::GetProject { id: 2 });
335 ///
336 /// // Chain calls and send them in a batch
337 /// # #[cfg(feature = "std")]
338 /// let replies = conn
339 /// .chain_call::<Methods>(&get_user, vec![])?
340 /// .append(&get_project, vec![])?
341 /// .send::<User, ApiError>().await?;
342 /// # #[cfg(not(feature = "std"))]
343 /// # let replies = conn
344 /// # .chain_call::<Methods>(&get_user)?
345 /// # .append(&get_project)?
346 /// # .send::<User, ApiError>().await?;
347 /// pin_mut!(replies);
348 ///
349 /// // Access replies sequentially.
350 /// # #[cfg(feature = "std")]
351 /// # {
352 /// let (user_reply, _fds) = replies.next().await.unwrap()?;
353 /// let (project_reply, _fds) = replies.next().await.unwrap()?;
354 ///
355 /// match user_reply {
356 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
357 /// Err(error) => println!("User error: {:?}", error),
358 /// }
359 /// # }
360 /// # #[cfg(not(feature = "std"))]
361 /// # {
362 /// # let user_reply = replies.next().await.unwrap()?;
363 /// # let _project_reply = replies.next().await.unwrap()?;
364 /// #
365 /// # match user_reply {
366 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
367 /// # Err(error) => println!("User error: {:?}", error),
368 /// # }
369 /// # }
370 /// # Ok(())
371 /// # }
372 /// ```
373 ///
374 /// ## Arbitrary Number of Calls
375 ///
376 /// ```no_run
377 /// # use zlink_core::{Connection, Call, reply};
378 /// # use serde::{Serialize, Deserialize};
379 /// # use futures_util::{pin_mut, stream::StreamExt};
380 /// # use serde_prefix_all::prefix_all;
381 /// # async fn example() -> zlink_core::Result<()> {
382 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
383 /// # #[prefix_all("org.example.")]
384 /// # #[derive(Debug, Serialize, Deserialize)]
385 /// # #[serde(tag = "method", content = "parameters")]
386 /// # enum Methods {
387 /// # GetUser { id: u32 },
388 /// # }
389 /// # #[derive(Debug, Deserialize)]
390 /// # struct User { name: String }
391 /// # #[derive(Debug, zlink_core::ReplyError)]
392 /// #[zlink(
393 /// interface = "org.example",
394 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
395 /// crate = "zlink_core",
396 /// )]
397 /// # enum ApiError {
398 /// # UserNotFound { code: i32 },
399 /// # ProjectNotFound { code: i32 },
400 /// # }
401 /// # let get_user = Call::new(Methods::GetUser { id: 1 });
402 ///
403 /// // Chain many calls (no upper limit)
404 /// # #[cfg(feature = "std")]
405 /// let mut chain = conn.chain_call::<Methods>(&get_user, vec![])?;
406 /// # #[cfg(not(feature = "std"))]
407 /// # let mut chain = conn.chain_call::<Methods>(&get_user)?;
408 /// # #[cfg(feature = "std")]
409 /// for i in 2..100 {
410 /// chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
411 /// }
412 /// # #[cfg(not(feature = "std"))]
413 /// # for i in 2..100 {
414 /// # chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
415 /// # }
416 ///
417 /// let replies = chain.send::<User, ApiError>().await?;
418 /// pin_mut!(replies);
419 ///
420 /// // Process all replies sequentially.
421 /// # #[cfg(feature = "std")]
422 /// while let Some(result) = replies.next().await {
423 /// let (user_reply, _fds) = result?;
424 /// // Handle each reply...
425 /// match user_reply {
426 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
427 /// Err(error) => println!("Error: {:?}", error),
428 /// }
429 /// }
430 /// # #[cfg(not(feature = "std"))]
431 /// # while let Some(result) = replies.next().await {
432 /// # let user_reply = result?;
433 /// # // Handle each reply...
434 /// # match user_reply {
435 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
436 /// # Err(error) => println!("Error: {:?}", error),
437 /// # }
438 /// # }
439 /// # Ok(())
440 /// # }
441 /// ```
442 ///
443 /// # Performance Benefits
444 ///
445 /// Instead of multiple write operations, the chain sends all calls in a single
446 /// write operation, reducing context switching and therefore minimizing latency.
447 pub fn chain_call<'c, Method>(
448 &'c mut self,
449 call: &Call<Method>,
450 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
451 ) -> Result<Chain<'c, S>>
452 where
453 Method: Serialize + Debug,
454 {
455 Chain::new(
456 self,
457 call,
458 #[cfg(feature = "std")]
459 fds,
460 )
461 }
462
463 /// Create a chain from an iterator of method calls.
464 ///
465 /// This allows creating a chain from any iterator yielding method types or calls. Each item
466 /// is automatically converted to a [`Call`] via [`Into<Call<Method>>`]. Unlike
467 /// [`Connection::chain_call`], this method allows building chains from dynamically-sized
468 /// collections.
469 ///
470 /// # Errors
471 ///
472 /// Returns [`Error::EmptyChain`] if the iterator is empty.
473 ///
474 /// # Examples
475 ///
476 /// ```no_run
477 /// use zlink_core::Connection;
478 /// use serde::{Serialize, Deserialize};
479 /// use serde_prefix_all::prefix_all;
480 /// use futures_util::{pin_mut, stream::StreamExt};
481 ///
482 /// # async fn example() -> zlink_core::Result<()> {
483 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
484 ///
485 /// #[prefix_all("org.example.")]
486 /// #[derive(Debug, Serialize, Deserialize)]
487 /// #[serde(tag = "method", content = "parameters")]
488 /// enum Methods {
489 /// GetUser { id: u32 },
490 /// }
491 ///
492 /// #[derive(Debug, Deserialize)]
493 /// struct User { name: String }
494 ///
495 /// #[derive(Debug, zlink_core::ReplyError)]
496 /// #[zlink(interface = "org.example", crate = "zlink_core")]
497 /// enum ApiError {
498 /// UserNotFound { code: i32 },
499 /// }
500 ///
501 /// let user_ids = [1, 2, 3, 4, 5];
502 /// let replies = conn
503 /// .chain_from_iter::<Methods, _, _>(
504 /// user_ids.iter().map(|&id| Methods::GetUser { id })
505 /// )?
506 /// .send::<User, ApiError>()
507 /// .await?;
508 /// pin_mut!(replies);
509 ///
510 /// # #[cfg(feature = "std")]
511 /// while let Some(result) = replies.next().await {
512 /// let (user_reply, _fds) = result?;
513 /// // Handle each reply...
514 /// }
515 /// # Ok(())
516 /// # }
517 /// ```
518 ///
519 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
520 pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
521 &'c mut self,
522 calls: MethodCalls,
523 ) -> Result<Chain<'c, S>>
524 where
525 Method: Serialize + Debug,
526 MethodCall: Into<Call<Method>>,
527 MethodCalls: IntoIterator<Item = MethodCall>,
528 {
529 let mut iter = calls.into_iter();
530 let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
531
532 #[cfg(feature = "std")]
533 let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
534 #[cfg(not(feature = "std"))]
535 let mut chain = Chain::new(self, &first)?;
536
537 for call in iter {
538 let call: Call<Method> = call.into();
539 #[cfg(feature = "std")]
540 {
541 chain = chain.append(&call, alloc::vec::Vec::new())?;
542 }
543 #[cfg(not(feature = "std"))]
544 {
545 chain = chain.append(&call)?;
546 }
547 }
548
549 Ok(chain)
550 }
551
552 /// Create a chain from an iterator of method calls with file descriptors.
553 ///
554 /// Similar to [`Connection::chain_from_iter`], but allows passing file descriptors with each
555 /// call. Each item in the iterator is a tuple of a method type (or [`Call`]) and its
556 /// associated file descriptors.
557 ///
558 /// # Errors
559 ///
560 /// Returns [`Error::EmptyChain`] if the iterator is empty.
561 ///
562 /// # Examples
563 ///
564 /// ```no_run
565 /// use zlink_core::Connection;
566 /// use serde::{Serialize, Deserialize};
567 /// use serde_prefix_all::prefix_all;
568 /// use std::os::fd::OwnedFd;
569 ///
570 /// # async fn example() -> zlink_core::Result<()> {
571 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
572 ///
573 /// #[prefix_all("org.example.")]
574 /// #[derive(Debug, Serialize, Deserialize)]
575 /// #[serde(tag = "method", content = "parameters")]
576 /// enum Methods {
577 /// SendFile { name: String },
578 /// }
579 ///
580 /// #[derive(Debug, Deserialize)]
581 /// struct FileResult { success: bool }
582 ///
583 /// #[derive(Debug, zlink_core::ReplyError)]
584 /// #[zlink(interface = "org.example", crate = "zlink_core")]
585 /// enum ApiError {
586 /// SendFailed { reason: String },
587 /// }
588 ///
589 /// let calls_with_fds: Vec<(Methods, Vec<OwnedFd>)> = vec![
590 /// (Methods::SendFile { name: "file1.txt".into() }, vec![/* fd1 */]),
591 /// (Methods::SendFile { name: "file2.txt".into() }, vec![/* fd2 */]),
592 /// ];
593 ///
594 /// let replies = conn
595 /// .chain_from_iter_with_fds::<Methods, _, _>(calls_with_fds)?
596 /// .send::<FileResult, ApiError>()
597 /// .await?;
598 /// # Ok(())
599 /// # }
600 /// ```
601 ///
602 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
603 #[cfg(feature = "std")]
604 pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
605 &'c mut self,
606 calls: MethodCalls,
607 ) -> Result<Chain<'c, S>>
608 where
609 Method: Serialize + Debug,
610 MethodCall: Into<Call<Method>>,
611 MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
612 {
613 let mut iter = calls.into_iter();
614 let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
615 let first: Call<Method> = first.into();
616 let mut chain = Chain::new(self, &first, first_fds)?;
617
618 for (call, fds) in iter {
619 let call: Call<Method> = call.into();
620 chain = chain.append(&call, fds)?;
621 }
622
623 Ok(chain)
624 }
625
626 /// Get the peer credentials.
627 ///
628 /// This method caches the credentials on the first call.
629 #[cfg(feature = "std")]
630 pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
631 where
632 S::ReadHalf: socket::FetchPeerCredentials,
633 {
634 if self.credentials.is_none() {
635 let creds = self.read.read_half().fetch_peer_credentials().await?;
636 self.credentials = Some(std::sync::Arc::new(creds));
637 }
638
639 // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
640 // method doesn't error out.
641 Ok(self.credentials.as_ref().unwrap())
642 }
643}
644
645impl<S> From<S> for Connection<S>
646where
647 S: Socket,
648{
649 fn from(socket: S) -> Self {
650 Self::new(socket)
651 }
652}
653
654pub(crate) const BUFFER_SIZE: usize = 256;
655const MAX_BUFFER_SIZE: usize = 100 * 1024 * 1024; // Don't allow buffers over 100MB.
656
657static NEXT_ID: AtomicUsize = AtomicUsize::new(0);