zlink_core/connection/mod.rs
1//! Contains connection related API.
2
3#[cfg(feature = "std")]
4mod credentials;
5mod read_connection;
6#[cfg(feature = "std")]
7pub use credentials::Credentials;
8pub use read_connection::ReadConnection;
9#[cfg(feature = "std")]
10pub use rustix::{process::Pid, process::Uid};
11pub mod chain;
12pub mod socket;
13#[cfg(test)]
14mod tests;
15mod write_connection;
16use crate::{
17 reply::{self, Reply},
18 Call, Result,
19};
20#[cfg(feature = "std")]
21use alloc::vec;
22pub use chain::Chain;
23use core::{fmt::Debug, sync::atomic::AtomicUsize};
24#[cfg(feature = "std")]
25use socket::FetchPeerCredentials;
26pub use write_connection::WriteConnection;
27
28use serde::{Deserialize, Serialize};
29pub use socket::Socket;
30
/// What the receive methods hand back.
///
/// Without `std` there is no file-descriptor passing, so this is just the received value.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// What the receive methods hand back: the received value together with the file descriptors
/// that arrived with it.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
36
37/// A connection.
38///
39/// The low-level API to send and receive messages.
40///
41/// Each connection gets a unique identifier when created that can be queried using
42/// [`Connection::id`]. This ID is shared between the read and write halves of the connection. It
43/// can be used to associate the read and write halves of the same connection.
44///
45/// # Cancel safety
46///
47/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
48/// documentation.
49#[derive(Debug)]
50pub struct Connection<S: Socket> {
51 read: ReadConnection<S::ReadHalf>,
52 write: WriteConnection<S::WriteHalf>,
53 #[cfg(feature = "std")]
54 credentials: Option<std::sync::Arc<Credentials>>,
55}
56
57impl<S> Connection<S>
58where
59 S: Socket,
60{
61 /// Create a new connection.
62 pub fn new(socket: S) -> Self {
63 let (read, write) = socket.split();
64 let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
65 Self {
66 read: ReadConnection::new(read, id),
67 write: WriteConnection::new(write, id),
68 #[cfg(feature = "std")]
69 credentials: None,
70 }
71 }
72
73 /// The reference to the read half of the connection.
74 pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
75 &self.read
76 }
77
78 /// The mutable reference to the read half of the connection.
79 pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
80 &mut self.read
81 }
82
83 /// The reference to the write half of the connection.
84 pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
85 &self.write
86 }
87
88 /// The mutable reference to the write half of the connection.
89 pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
90 &mut self.write
91 }
92
93 /// Split the connection into read and write halves.
94 ///
95 /// Note: This consumes any cached credentials. If you need the credentials after splitting,
96 /// call [`Connection::peer_credentials`] before splitting.
97 pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
98 (self.read, self.write)
99 }
100
101 /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
102 pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
103 Self {
104 read,
105 write,
106 #[cfg(feature = "std")]
107 credentials: None,
108 }
109 }
110
111 /// The unique identifier of the connection.
112 pub fn id(&self) -> usize {
113 assert_eq!(self.read.id(), self.write.id());
114 self.read.id()
115 }
116
117 /// Sends a method call.
118 ///
119 /// Convenience wrapper around [`WriteConnection::send_call`].
120 pub async fn send_call<Method>(
121 &mut self,
122 call: &Call<Method>,
123 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
124 ) -> Result<()>
125 where
126 Method: Serialize + Debug,
127 {
128 #[cfg(feature = "std")]
129 {
130 self.write.send_call(call, fds).await
131 }
132 #[cfg(not(feature = "std"))]
133 {
134 self.write.send_call(call).await
135 }
136 }
137
138 /// Receives a method call reply.
139 ///
140 /// Convenience wrapper around [`ReadConnection::receive_reply`].
141 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
142 &'r mut self,
143 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
144 where
145 ReplyParams: Deserialize<'r> + Debug,
146 ReplyError: Deserialize<'r> + Debug,
147 {
148 self.read.receive_reply().await
149 }
150
151 /// Call a method and receive a reply.
152 ///
153 /// This is a convenience method that combines [`Connection::send_call`] and
154 /// [`Connection::receive_reply`].
155 pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
156 &'r mut self,
157 call: &Call<Method>,
158 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
159 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
160 where
161 Method: Serialize + Debug,
162 ReplyParams: Deserialize<'r> + Debug,
163 ReplyError: Deserialize<'r> + Debug,
164 {
165 #[cfg(feature = "std")]
166 self.send_call(call, fds).await?;
167 #[cfg(not(feature = "std"))]
168 self.send_call(call).await?;
169
170 self.receive_reply().await
171 }
172
173 /// Receive a method call over the socket.
174 ///
175 /// Convenience wrapper around [`ReadConnection::receive_call`].
176 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
177 where
178 Method: Deserialize<'m> + Debug,
179 {
180 self.read.receive_call().await
181 }
182
183 /// Send a reply over the socket.
184 ///
185 /// Convenience wrapper around [`WriteConnection::send_reply`].
186 pub async fn send_reply<ReplyParams>(
187 &mut self,
188 reply: &Reply<ReplyParams>,
189 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
190 ) -> Result<()>
191 where
192 ReplyParams: Serialize + Debug,
193 {
194 #[cfg(feature = "std")]
195 {
196 self.write.send_reply(reply, fds).await
197 }
198 #[cfg(not(feature = "std"))]
199 {
200 self.write.send_reply(reply).await
201 }
202 }
203
204 /// Send an error reply over the socket.
205 ///
206 /// Convenience wrapper around [`WriteConnection::send_error`].
207 pub async fn send_error<ReplyError>(
208 &mut self,
209 error: &ReplyError,
210 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
211 ) -> Result<()>
212 where
213 ReplyError: Serialize + Debug,
214 {
215 #[cfg(feature = "std")]
216 {
217 self.write.send_error(error, fds).await
218 }
219 #[cfg(not(feature = "std"))]
220 {
221 self.write.send_error(error).await
222 }
223 }
224
225 /// Enqueue a call to the server.
226 ///
227 /// Convenience wrapper around [`WriteConnection::enqueue_call`].
228 pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
229 where
230 Method: Serialize + Debug,
231 {
232 #[cfg(feature = "std")]
233 {
234 self.write.enqueue_call(method, vec![])
235 }
236 #[cfg(not(feature = "std"))]
237 {
238 self.write.enqueue_call(method)
239 }
240 }
241
242 /// Flush the connection.
243 ///
244 /// Convenience wrapper around [`WriteConnection::flush`].
245 pub async fn flush(&mut self) -> Result<()> {
246 self.write.flush().await
247 }
248
249 /// Start a chain of method calls.
250 ///
251 /// This allows batching multiple calls together and sending them in a single write operation.
252 ///
253 /// # Examples
254 ///
255 /// ## Basic Usage with Sequential Access
256 ///
257 /// ```no_run
258 /// use zlink_core::{Connection, Call, reply};
259 /// use serde::{Serialize, Deserialize};
260 /// use serde_prefix_all::prefix_all;
261 /// use futures_util::{pin_mut, stream::StreamExt};
262 ///
263 /// # async fn example() -> zlink_core::Result<()> {
264 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
265 ///
266 /// #[prefix_all("org.example.")]
267 /// #[derive(Debug, Serialize, Deserialize)]
268 /// #[serde(tag = "method", content = "parameters")]
269 /// enum Methods {
270 /// GetUser { id: u32 },
271 /// GetProject { id: u32 },
272 /// }
273 ///
274 /// #[derive(Debug, Deserialize)]
275 /// struct User { name: String }
276 ///
277 /// #[derive(Debug, Deserialize)]
278 /// struct Project { title: String }
279 ///
280 /// #[derive(Debug, zlink_core::ReplyError)]
281 /// #[zlink(
282 /// interface = "org.example",
283 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
284 /// crate = "zlink_core",
285 /// )]
286 /// enum ApiError {
287 /// UserNotFound { code: i32 },
288 /// ProjectNotFound { code: i32 },
289 /// }
290 ///
291 /// let get_user = Call::new(Methods::GetUser { id: 1 });
292 /// let get_project = Call::new(Methods::GetProject { id: 2 });
293 ///
294 /// // Chain calls and send them in a batch
295 /// # #[cfg(feature = "std")]
296 /// let replies = conn
297 /// .chain_call::<Methods>(&get_user, vec![])?
298 /// .append(&get_project, vec![])?
299 /// .send::<User, ApiError>().await?;
300 /// # #[cfg(not(feature = "std"))]
301 /// # let replies = conn
302 /// # .chain_call::<Methods>(&get_user)?
303 /// # .append(&get_project)?
304 /// # .send::<User, ApiError>().await?;
305 /// pin_mut!(replies);
306 ///
307 /// // Access replies sequentially.
308 /// # #[cfg(feature = "std")]
309 /// # {
310 /// let (user_reply, _fds) = replies.next().await.unwrap()?;
311 /// let (project_reply, _fds) = replies.next().await.unwrap()?;
312 ///
313 /// match user_reply {
314 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
315 /// Err(error) => println!("User error: {:?}", error),
316 /// }
317 /// # }
318 /// # #[cfg(not(feature = "std"))]
319 /// # {
320 /// # let user_reply = replies.next().await.unwrap()?;
321 /// # let _project_reply = replies.next().await.unwrap()?;
322 /// #
323 /// # match user_reply {
324 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
325 /// # Err(error) => println!("User error: {:?}", error),
326 /// # }
327 /// # }
328 /// # Ok(())
329 /// # }
330 /// ```
331 ///
332 /// ## Arbitrary Number of Calls
333 ///
334 /// ```no_run
335 /// # use zlink_core::{Connection, Call, reply};
336 /// # use serde::{Serialize, Deserialize};
337 /// # use futures_util::{pin_mut, stream::StreamExt};
338 /// # use serde_prefix_all::prefix_all;
339 /// # async fn example() -> zlink_core::Result<()> {
340 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
341 /// # #[prefix_all("org.example.")]
342 /// # #[derive(Debug, Serialize, Deserialize)]
343 /// # #[serde(tag = "method", content = "parameters")]
344 /// # enum Methods {
345 /// # GetUser { id: u32 },
346 /// # }
347 /// # #[derive(Debug, Deserialize)]
348 /// # struct User { name: String }
349 /// # #[derive(Debug, zlink_core::ReplyError)]
350 /// #[zlink(
351 /// interface = "org.example",
352 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
353 /// crate = "zlink_core",
354 /// )]
355 /// # enum ApiError {
356 /// # UserNotFound { code: i32 },
357 /// # ProjectNotFound { code: i32 },
358 /// # }
359 /// # let get_user = Call::new(Methods::GetUser { id: 1 });
360 ///
361 /// // Chain many calls (no upper limit)
362 /// # #[cfg(feature = "std")]
363 /// let mut chain = conn.chain_call::<Methods>(&get_user, vec![])?;
364 /// # #[cfg(not(feature = "std"))]
365 /// # let mut chain = conn.chain_call::<Methods>(&get_user)?;
366 /// # #[cfg(feature = "std")]
367 /// for i in 2..100 {
368 /// chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
369 /// }
370 /// # #[cfg(not(feature = "std"))]
371 /// # for i in 2..100 {
372 /// # chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
373 /// # }
374 ///
375 /// let replies = chain.send::<User, ApiError>().await?;
376 /// pin_mut!(replies);
377 ///
378 /// // Process all replies sequentially.
379 /// # #[cfg(feature = "std")]
380 /// while let Some(result) = replies.next().await {
381 /// let (user_reply, _fds) = result?;
382 /// // Handle each reply...
383 /// match user_reply {
384 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
385 /// Err(error) => println!("Error: {:?}", error),
386 /// }
387 /// }
388 /// # #[cfg(not(feature = "std"))]
389 /// # while let Some(result) = replies.next().await {
390 /// # let user_reply = result?;
391 /// # // Handle each reply...
392 /// # match user_reply {
393 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
394 /// # Err(error) => println!("Error: {:?}", error),
395 /// # }
396 /// # }
397 /// # Ok(())
398 /// # }
399 /// ```
400 ///
401 /// # Performance Benefits
402 ///
403 /// Instead of multiple write operations, the chain sends all calls in a single
404 /// write operation, reducing context switching and therefore minimizing latency.
405 pub fn chain_call<'c, Method>(
406 &'c mut self,
407 call: &Call<Method>,
408 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
409 ) -> Result<Chain<'c, S>>
410 where
411 Method: Serialize + Debug,
412 {
413 Chain::new(
414 self,
415 call,
416 #[cfg(feature = "std")]
417 fds,
418 )
419 }
420
421 /// Create a chain from an iterator of method calls.
422 ///
423 /// This allows creating a chain from any iterator yielding method types or calls. Each item
424 /// is automatically converted to a [`Call`] via [`Into<Call<Method>>`]. Unlike
425 /// [`Connection::chain_call`], this method allows building chains from dynamically-sized
426 /// collections.
427 ///
428 /// # Errors
429 ///
430 /// Returns [`Error::EmptyChain`] if the iterator is empty.
431 ///
432 /// # Examples
433 ///
434 /// ```no_run
435 /// use zlink_core::Connection;
436 /// use serde::{Serialize, Deserialize};
437 /// use serde_prefix_all::prefix_all;
438 /// use futures_util::{pin_mut, stream::StreamExt};
439 ///
440 /// # async fn example() -> zlink_core::Result<()> {
441 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
442 ///
443 /// #[prefix_all("org.example.")]
444 /// #[derive(Debug, Serialize, Deserialize)]
445 /// #[serde(tag = "method", content = "parameters")]
446 /// enum Methods {
447 /// GetUser { id: u32 },
448 /// }
449 ///
450 /// #[derive(Debug, Deserialize)]
451 /// struct User { name: String }
452 ///
453 /// #[derive(Debug, zlink_core::ReplyError)]
454 /// #[zlink(interface = "org.example", crate = "zlink_core")]
455 /// enum ApiError {
456 /// UserNotFound { code: i32 },
457 /// }
458 ///
459 /// let user_ids = [1, 2, 3, 4, 5];
460 /// let replies = conn
461 /// .chain_from_iter::<Methods, _, _>(
462 /// user_ids.iter().map(|&id| Methods::GetUser { id })
463 /// )?
464 /// .send::<User, ApiError>()
465 /// .await?;
466 /// pin_mut!(replies);
467 ///
468 /// # #[cfg(feature = "std")]
469 /// while let Some(result) = replies.next().await {
470 /// let (user_reply, _fds) = result?;
471 /// // Handle each reply...
472 /// }
473 /// # Ok(())
474 /// # }
475 /// ```
476 ///
477 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
478 pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
479 &'c mut self,
480 calls: MethodCalls,
481 ) -> Result<Chain<'c, S>>
482 where
483 Method: Serialize + Debug,
484 MethodCall: Into<Call<Method>>,
485 MethodCalls: IntoIterator<Item = MethodCall>,
486 {
487 let mut iter = calls.into_iter();
488 let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
489
490 #[cfg(feature = "std")]
491 let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
492 #[cfg(not(feature = "std"))]
493 let mut chain = Chain::new(self, &first)?;
494
495 for call in iter {
496 let call: Call<Method> = call.into();
497 #[cfg(feature = "std")]
498 {
499 chain = chain.append(&call, alloc::vec::Vec::new())?;
500 }
501 #[cfg(not(feature = "std"))]
502 {
503 chain = chain.append(&call)?;
504 }
505 }
506
507 Ok(chain)
508 }
509
510 /// Create a chain from an iterator of method calls with file descriptors.
511 ///
512 /// Similar to [`Connection::chain_from_iter`], but allows passing file descriptors with each
513 /// call. Each item in the iterator is a tuple of a method type (or [`Call`]) and its
514 /// associated file descriptors.
515 ///
516 /// # Errors
517 ///
518 /// Returns [`Error::EmptyChain`] if the iterator is empty.
519 ///
520 /// # Examples
521 ///
522 /// ```no_run
523 /// use zlink_core::Connection;
524 /// use serde::{Serialize, Deserialize};
525 /// use serde_prefix_all::prefix_all;
526 /// use std::os::fd::OwnedFd;
527 ///
528 /// # async fn example() -> zlink_core::Result<()> {
529 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
530 ///
531 /// #[prefix_all("org.example.")]
532 /// #[derive(Debug, Serialize, Deserialize)]
533 /// #[serde(tag = "method", content = "parameters")]
534 /// enum Methods {
535 /// SendFile { name: String },
536 /// }
537 ///
538 /// #[derive(Debug, Deserialize)]
539 /// struct FileResult { success: bool }
540 ///
541 /// #[derive(Debug, zlink_core::ReplyError)]
542 /// #[zlink(interface = "org.example", crate = "zlink_core")]
543 /// enum ApiError {
544 /// SendFailed { reason: String },
545 /// }
546 ///
547 /// let calls_with_fds: Vec<(Methods, Vec<OwnedFd>)> = vec![
548 /// (Methods::SendFile { name: "file1.txt".into() }, vec![/* fd1 */]),
549 /// (Methods::SendFile { name: "file2.txt".into() }, vec![/* fd2 */]),
550 /// ];
551 ///
552 /// let replies = conn
553 /// .chain_from_iter_with_fds::<Methods, _, _>(calls_with_fds)?
554 /// .send::<FileResult, ApiError>()
555 /// .await?;
556 /// # Ok(())
557 /// # }
558 /// ```
559 ///
560 /// [`Error::EmptyChain`]: crate::Error::EmptyChain
561 #[cfg(feature = "std")]
562 pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
563 &'c mut self,
564 calls: MethodCalls,
565 ) -> Result<Chain<'c, S>>
566 where
567 Method: Serialize + Debug,
568 MethodCall: Into<Call<Method>>,
569 MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
570 {
571 let mut iter = calls.into_iter();
572 let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
573 let first: Call<Method> = first.into();
574 let mut chain = Chain::new(self, &first, first_fds)?;
575
576 for (call, fds) in iter {
577 let call: Call<Method> = call.into();
578 chain = chain.append(&call, fds)?;
579 }
580
581 Ok(chain)
582 }
583
584 /// Get the peer credentials.
585 ///
586 /// This method caches the credentials on the first call.
587 #[cfg(feature = "std")]
588 pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
589 where
590 S::ReadHalf: socket::FetchPeerCredentials,
591 {
592 if self.credentials.is_none() {
593 let creds = self.read.read_half().fetch_peer_credentials().await?;
594 self.credentials = Some(std::sync::Arc::new(creds));
595 }
596
597 // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
598 // method doesn't error out.
599 Ok(self.credentials.as_ref().unwrap())
600 }
601}
602
603impl<S> From<S> for Connection<S>
604where
605 S: Socket,
606{
607 fn from(socket: S) -> Self {
608 Self::new(socket)
609 }
610}
611
/// Buffer size, in bytes, shared with the read/write halves (presumably their initial
/// allocation — confirm in `read_connection` / `write_connection`).
pub(crate) const BUFFER_SIZE: usize = 256;
/// Buffers are not allowed to grow beyond this many bytes: 100 MiB (`100 * 1024 * 1024`).
const MAX_BUFFER_SIZE: usize = 100 << 20;

/// Counter from which every new connection draws its unique ID.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);