Skip to main content

zlink_core/connection/chain/
reply_stream.rs

1use alloc::boxed::Box;
2use core::{
3    fmt::Debug,
4    pin::Pin,
5    task::{Context, Poll},
6};
7use futures_util::stream::{Stream, unfold};
8use serde::de::DeserializeOwned;
9
10use crate::{
11    Result,
12    connection::{ReadConnection, socket::ReadHalf},
13    reply,
14};
15
16#[cfg(feature = "std")]
17use std::os::fd::OwnedFd;
18
/// Type alias for chain reply results.
///
/// Exactly one of the two definitions of this alias is compiled, selected by the `std` feature.
///
/// In std mode, this is a tuple of the reply result and the file descriptors received with the
/// reply.
/// In no_std mode, just the reply result.
#[cfg(feature = "std")]
pub(crate) type ChainResult<Params, ReplyError> =
    (reply::Result<Params, ReplyError>, alloc::vec::Vec<OwnedFd>);

/// Type alias for chain reply results (no_std variant).
///
/// Without the `std` feature no file descriptors are carried, so this is just the reply result.
#[cfg(not(feature = "std"))]
pub(crate) type ChainResult<Params, ReplyError> = reply::Result<Params, ReplyError>;
29
30/// A stream of replies from a chain of method calls.
31///
32/// # Owned Data Requirement
33///
34/// Stream items must use owned types (`DeserializeOwned`) rather than borrowed types. This is
35/// because the internal buffer may be reused between stream iterations, which would invalidate
36/// borrowed references. This limitation may be lifted in the future when Rust supports lending
37/// streams.
38///
39/// This is used internally by the proxy macro for streaming methods.
40pub struct ReplyStream<'c, Params, ReplyError> {
41    inner: InnerStream<'c, Params, ReplyError>,
42}
43
44impl<'c, Params, ReplyError> ReplyStream<'c, Params, ReplyError>
45where
46    Params: DeserializeOwned + Debug,
47    ReplyError: DeserializeOwned + Debug,
48{
49    /// Create a new reply stream.
50    ///
51    /// The stream will yield `reply_count` replies from the connection.
52    pub fn new<Read>(connection: &'c mut ReadConnection<Read>, reply_count: usize) -> Self
53    where
54        Read: ReadHalf + 'c,
55    {
56        // State is (connection, current_index). The connection reference flows through each
57        // iteration.
58        let inner = unfold(
59            (connection, 0),
60            move |(conn, mut current_index)| async move {
61                if current_index >= reply_count {
62                    return None;
63                }
64
65                let item = conn.receive_reply::<Params, ReplyError>().await;
66                let item_ref = item.as_ref();
67                #[cfg(feature = "std")]
68                // In std mode, we need to ignore the FDs.
69                let item_ref = item_ref.map(|r| &r.0);
70
71                // Update index based on result.
72                match item_ref {
73                    Ok(Ok(r)) if r.continues() != Some(true) => {
74                        current_index += 1;
75                    }
76                    Ok(Ok(_)) => {
77                        // Streaming reply, don't increment index yet.
78                    }
79                    Ok(Err(_)) => {
80                        // For method errors, always increment since there won't be more
81                        // replies.
82                        current_index += 1;
83                    }
84                    Err(_) => {
85                        // General error, mark stream as done.
86                        current_index = reply_count;
87                    }
88                }
89
90                Some((item, (conn, current_index)))
91            },
92        );
93
94        Self {
95            inner: Box::pin(inner),
96        }
97    }
98}
99
100impl<Params, ReplyError> Debug for ReplyStream<'_, Params, ReplyError> {
101    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
102        f.debug_struct("ReplyStream").finish_non_exhaustive()
103    }
104}
105
106impl<Params, ReplyError> Stream for ReplyStream<'_, Params, ReplyError> {
107    type Item = Result<ChainResult<Params, ReplyError>>;
108
109    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110        self.inner.as_mut().poll_next(cx)
111    }
112}
113
/// The boxed inner stream type for `ReplyStream`.
///
/// A pinned, heap-allocated trait object, so the concrete stream type built in
/// `ReplyStream::new` does not appear in `ReplyStream`'s signature. The `'c` bound limits the
/// trait object to the lifetime of the borrowed connection.
type InnerStream<'c, Params, ReplyError> =
    Pin<Box<dyn Stream<Item = Result<ChainResult<Params, ReplyError>>> + 'c>>;