zlink_core/connection/mod.rs
1//! Contains connection related API.
2
3#[cfg(feature = "std")]
4mod credentials;
5mod read_connection;
6#[cfg(feature = "std")]
7pub use credentials::Credentials;
8pub use read_connection::ReadConnection;
9#[cfg(feature = "std")]
10pub use rustix::{process::Pid, process::Uid};
11pub mod chain;
12pub mod socket;
13#[cfg(test)]
14mod tests;
15mod write_connection;
16use crate::{
17 reply::{self, Reply},
18 Call, Result,
19};
20#[cfg(feature = "std")]
21use alloc::vec;
22pub use chain::Chain;
23use core::{fmt::Debug, sync::atomic::AtomicUsize};
24#[cfg(feature = "std")]
25use socket::FetchPeerCredentials;
26pub use write_connection::WriteConnection;
27
28use serde::{Deserialize, Serialize};
29pub use socket::Socket;
30
/// Value produced by the `receive_*` methods when the `std` feature is disabled: just the
/// received item, since no file descriptors can be transferred.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// Value produced by the `receive_*` methods when the `std` feature is enabled: the received
/// item paired with the file descriptors that came with it.
#[cfg(feature = "std")]
type RecvResult<T> = (T, Vec<std::os::fd::OwnedFd>);
36
37/// A connection.
38///
39/// The low-level API to send and receive messages.
40///
41/// Each connection gets a unique identifier when created that can be queried using
42/// [`Connection::id`]. This ID is shared betwen the read and write halves of the connection. It
43/// can be used to associate the read and write halves of the same connection.
44///
45/// # Cancel safety
46///
47/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
48/// documentation.
49#[derive(Debug)]
50pub struct Connection<S: Socket> {
51 read: ReadConnection<S::ReadHalf>,
52 write: WriteConnection<S::WriteHalf>,
53 #[cfg(feature = "std")]
54 credentials: Option<std::sync::Arc<Credentials>>,
55}
56
57impl<S> Connection<S>
58where
59 S: Socket,
60{
61 /// Create a new connection.
62 pub fn new(socket: S) -> Self {
63 let (read, write) = socket.split();
64 let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
65 Self {
66 read: ReadConnection::new(read, id),
67 write: WriteConnection::new(write, id),
68 #[cfg(feature = "std")]
69 credentials: None,
70 }
71 }
72
73 /// The reference to the read half of the connection.
74 pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
75 &self.read
76 }
77
78 /// The mutable reference to the read half of the connection.
79 pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
80 &mut self.read
81 }
82
83 /// The reference to the write half of the connection.
84 pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
85 &self.write
86 }
87
88 /// The mutable reference to the write half of the connection.
89 pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
90 &mut self.write
91 }
92
93 /// Split the connection into read and write halves.
94 ///
95 /// Note: This consumes any cached credentials. If you need the credentials after splitting,
96 /// call [`Connection::peer_credentials`] before splitting.
97 pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
98 (self.read, self.write)
99 }
100
101 /// Join the read and write halves into a connection (the opposite of [`Connection::split`]).
102 pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
103 Self {
104 read,
105 write,
106 #[cfg(feature = "std")]
107 credentials: None,
108 }
109 }
110
111 /// The unique identifier of the connection.
112 pub fn id(&self) -> usize {
113 assert_eq!(self.read.id(), self.write.id());
114 self.read.id()
115 }
116
117 /// Sends a method call.
118 ///
119 /// Convenience wrapper around [`WriteConnection::send_call`].
120 pub async fn send_call<Method>(
121 &mut self,
122 call: &Call<Method>,
123 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
124 ) -> Result<()>
125 where
126 Method: Serialize + Debug,
127 {
128 #[cfg(feature = "std")]
129 {
130 self.write.send_call(call, fds).await
131 }
132 #[cfg(not(feature = "std"))]
133 {
134 self.write.send_call(call).await
135 }
136 }
137
138 /// Receives a method call reply.
139 ///
140 /// Convenience wrapper around [`ReadConnection::receive_reply`].
141 pub async fn receive_reply<'r, ReplyParams, ReplyError>(
142 &'r mut self,
143 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
144 where
145 ReplyParams: Deserialize<'r> + Debug,
146 ReplyError: Deserialize<'r> + Debug,
147 {
148 self.read.receive_reply().await
149 }
150
151 /// Call a method and receive a reply.
152 ///
153 /// This is a convenience method that combines [`Connection::send_call`] and
154 /// [`Connection::receive_reply`].
155 pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
156 &'r mut self,
157 call: &Call<Method>,
158 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
159 ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
160 where
161 Method: Serialize + Debug,
162 ReplyParams: Deserialize<'r> + Debug,
163 ReplyError: Deserialize<'r> + Debug,
164 {
165 #[cfg(feature = "std")]
166 self.send_call(call, fds).await?;
167 #[cfg(not(feature = "std"))]
168 self.send_call(call).await?;
169
170 self.receive_reply().await
171 }
172
173 /// Receive a method call over the socket.
174 ///
175 /// Convenience wrapper around [`ReadConnection::receive_call`].
176 pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
177 where
178 Method: Deserialize<'m> + Debug,
179 {
180 self.read.receive_call().await
181 }
182
183 /// Send a reply over the socket.
184 ///
185 /// Convenience wrapper around [`WriteConnection::send_reply`].
186 pub async fn send_reply<ReplyParams>(
187 &mut self,
188 reply: &Reply<ReplyParams>,
189 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
190 ) -> Result<()>
191 where
192 ReplyParams: Serialize + Debug,
193 {
194 #[cfg(feature = "std")]
195 {
196 self.write.send_reply(reply, fds).await
197 }
198 #[cfg(not(feature = "std"))]
199 {
200 self.write.send_reply(reply).await
201 }
202 }
203
204 /// Send an error reply over the socket.
205 ///
206 /// Convenience wrapper around [`WriteConnection::send_error`].
207 pub async fn send_error<ReplyError>(
208 &mut self,
209 error: &ReplyError,
210 #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
211 ) -> Result<()>
212 where
213 ReplyError: Serialize + Debug,
214 {
215 #[cfg(feature = "std")]
216 {
217 self.write.send_error(error, fds).await
218 }
219 #[cfg(not(feature = "std"))]
220 {
221 self.write.send_error(error).await
222 }
223 }
224
225 /// Enqueue a call to the server.
226 ///
227 /// Convenience wrapper around [`WriteConnection::enqueue_call`].
228 pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
229 where
230 Method: Serialize + Debug,
231 {
232 #[cfg(feature = "std")]
233 {
234 self.write.enqueue_call(method, vec![])
235 }
236 #[cfg(not(feature = "std"))]
237 {
238 self.write.enqueue_call(method)
239 }
240 }
241
242 /// Flush the connection.
243 ///
244 /// Convenience wrapper around [`WriteConnection::flush`].
245 pub async fn flush(&mut self) -> Result<()> {
246 self.write.flush().await
247 }
248
249 /// Start a chain of method calls.
250 ///
251 /// This allows batching multiple calls together and sending them in a single write operation.
252 ///
253 /// # Examples
254 ///
255 /// ## Basic Usage with Sequential Access
256 ///
257 /// ```no_run
258 /// use zlink_core::{Connection, Call, reply};
259 /// use serde::{Serialize, Deserialize};
260 /// use serde_prefix_all::prefix_all;
261 /// use futures_util::{pin_mut, stream::StreamExt};
262 ///
263 /// # async fn example() -> zlink_core::Result<()> {
264 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
265 ///
266 /// #[prefix_all("org.example.")]
267 /// #[derive(Debug, Serialize, Deserialize)]
268 /// #[serde(tag = "method", content = "parameters")]
269 /// enum Methods {
270 /// GetUser { id: u32 },
271 /// GetProject { id: u32 },
272 /// }
273 ///
274 /// #[derive(Debug, Deserialize)]
275 /// struct User { name: String }
276 ///
277 /// #[derive(Debug, Deserialize)]
278 /// struct Project { title: String }
279 ///
280 /// #[derive(Debug, zlink_core::ReplyError)]
281 /// #[zlink(
282 /// interface = "org.example",
283 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
284 /// crate = "zlink_core",
285 /// )]
286 /// enum ApiError {
287 /// UserNotFound { code: i32 },
288 /// ProjectNotFound { code: i32 },
289 /// }
290 ///
291 /// let get_user = Call::new(Methods::GetUser { id: 1 });
292 /// let get_project = Call::new(Methods::GetProject { id: 2 });
293 ///
294 /// // Chain calls and send them in a batch
295 /// # #[cfg(feature = "std")]
296 /// let replies = conn
297 /// .chain_call::<Methods, User, ApiError>(&get_user, vec![])?
298 /// .append(&get_project, vec![])?
299 /// .send().await?;
300 /// # #[cfg(not(feature = "std"))]
301 /// # let replies = conn
302 /// # .chain_call::<Methods, User, ApiError>(&get_user)?
303 /// # .append(&get_project)?
304 /// # .send().await?;
305 /// pin_mut!(replies);
306 ///
307 /// // Access replies sequentially - types are now fixed by the chain
308 /// # #[cfg(feature = "std")]
309 /// # {
310 /// let (user_reply, _fds) = replies.next().await.unwrap()?;
311 /// let (project_reply, _fds) = replies.next().await.unwrap()?;
312 ///
313 /// match user_reply {
314 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
315 /// Err(error) => println!("User error: {:?}", error),
316 /// }
317 /// # }
318 /// # #[cfg(not(feature = "std"))]
319 /// # {
320 /// # let user_reply = replies.next().await.unwrap()?;
321 /// # let _project_reply = replies.next().await.unwrap()?;
322 /// #
323 /// # match user_reply {
324 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
325 /// # Err(error) => println!("User error: {:?}", error),
326 /// # }
327 /// # }
328 /// # Ok(())
329 /// # }
330 /// ```
331 ///
332 /// ## Arbitrary Number of Calls
333 ///
334 /// ```no_run
335 /// # use zlink_core::{Connection, Call, reply};
336 /// # use serde::{Serialize, Deserialize};
337 /// # use futures_util::{pin_mut, stream::StreamExt};
338 /// # use serde_prefix_all::prefix_all;
339 /// # async fn example() -> zlink_core::Result<()> {
340 /// # let mut conn: Connection<zlink_core::connection::socket::impl_for_doc::Socket> = todo!();
341 /// # #[prefix_all("org.example.")]
342 /// # #[derive(Debug, Serialize, Deserialize)]
343 /// # #[serde(tag = "method", content = "parameters")]
344 /// # enum Methods {
345 /// # GetUser { id: u32 },
346 /// # }
347 /// # #[derive(Debug, Deserialize)]
348 /// # struct User { name: String }
349 /// # #[derive(Debug, zlink_core::ReplyError)]
350 /// #[zlink(
351 /// interface = "org.example",
352 /// // Not needed in the real code because you'll use `ReplyError` through `zlink` crate.
353 /// crate = "zlink_core",
354 /// )]
355 /// # enum ApiError {
356 /// # UserNotFound { code: i32 },
357 /// # ProjectNotFound { code: i32 },
358 /// # }
359 /// # let get_user = Call::new(Methods::GetUser { id: 1 });
360 ///
361 /// // Chain many calls (no upper limit)
362 /// # #[cfg(feature = "std")]
363 /// let mut chain = conn.chain_call::<Methods, User, ApiError>(&get_user, vec![])?;
364 /// # #[cfg(not(feature = "std"))]
365 /// # let mut chain = conn.chain_call::<Methods, User, ApiError>(&get_user)?;
366 /// # #[cfg(feature = "std")]
367 /// for i in 2..100 {
368 /// chain = chain.append(&Call::new(Methods::GetUser { id: i }), vec![])?;
369 /// }
370 /// # #[cfg(not(feature = "std"))]
371 /// # for i in 2..100 {
372 /// # chain = chain.append(&Call::new(Methods::GetUser { id: i }))?;
373 /// # }
374 ///
375 /// let replies = chain.send().await?;
376 /// pin_mut!(replies);
377 ///
378 /// // Process all replies sequentially - types are fixed by the chain
379 /// # #[cfg(feature = "std")]
380 /// while let Some(result) = replies.next().await {
381 /// let (user_reply, _fds) = result?;
382 /// // Handle each reply...
383 /// match user_reply {
384 /// Ok(user) => println!("User: {}", user.parameters().unwrap().name),
385 /// Err(error) => println!("Error: {:?}", error),
386 /// }
387 /// }
388 /// # #[cfg(not(feature = "std"))]
389 /// # while let Some(result) = replies.next().await {
390 /// # let user_reply = result?;
391 /// # // Handle each reply...
392 /// # match user_reply {
393 /// # Ok(user) => println!("User: {}", user.parameters().unwrap().name),
394 /// # Err(error) => println!("Error: {:?}", error),
395 /// # }
396 /// # }
397 /// # Ok(())
398 /// # }
399 /// ```
400 ///
401 /// # Performance Benefits
402 ///
403 /// Instead of multiple write operations, the chain sends all calls in a single
404 /// write operation, reducing context switching and therefore minimizing latency.
405 pub fn chain_call<'c, Method, ReplyParams, ReplyError>(
406 &'c mut self,
407 call: &Call<Method>,
408 #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
409 ) -> Result<Chain<'c, S, ReplyParams, ReplyError>>
410 where
411 Method: Serialize + Debug,
412 ReplyParams: Deserialize<'c> + Debug,
413 ReplyError: Deserialize<'c> + Debug,
414 {
415 Chain::new(
416 self,
417 call,
418 #[cfg(feature = "std")]
419 fds,
420 )
421 }
422
423 /// Get the peer credentials.
424 ///
425 /// This method caches the credentials on the first call.
426 #[cfg(feature = "std")]
427 pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
428 where
429 S::ReadHalf: socket::FetchPeerCredentials,
430 {
431 if self.credentials.is_none() {
432 let creds = self.read.read_half().fetch_peer_credentials().await?;
433 self.credentials = Some(std::sync::Arc::new(creds));
434 }
435
436 // Safety: `unwrap` won't panic because we ensure above that it's set correctly if the
437 // method doesn't error out.
438 Ok(self.credentials.as_ref().unwrap())
439 }
440}
441
442impl<S> From<S> for Connection<S>
443where
444 S: Socket,
445{
446 fn from(socket: S) -> Self {
447 Self::new(socket)
448 }
449}
450
/// Default buffer size shared with the read/write halves (presumably the initial allocation in
/// bytes — confirm against `read_connection`/`write_connection`).
pub(crate) const BUFFER_SIZE: usize = 256;
/// Upper limit on buffer size: buffers over 100MB (100 << 20) are not allowed.
const MAX_BUFFER_SIZE: usize = 100 << 20;

/// Counter from which `Connection::new` draws each connection's ID; every draw returns the
/// current value and bumps it by one.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);