zlink_core/connection/chain/reply_stream.rs

1use alloc::boxed::Box;
2use core::{
3    fmt::Debug,
4    pin::Pin,
5    task::{Context, Poll},
6};
7use futures_util::stream::{unfold, Stream};
8use serde::de::DeserializeOwned;
9
10use crate::{
11    connection::{socket::ReadHalf, ReadConnection},
12    reply, Result,
13};
14
15#[cfg(feature = "std")]
16use std::os::fd::OwnedFd;
17
18/// Type alias for chain reply results.
19///
20/// In std mode, includes file descriptors received with the reply.
21/// In no_std mode, just the reply result.
22#[cfg(feature = "std")]
23pub(crate) type ChainResult<Params, ReplyError> =
24    (reply::Result<Params, ReplyError>, alloc::vec::Vec<OwnedFd>);
25
26#[cfg(not(feature = "std"))]
27pub(crate) type ChainResult<Params, ReplyError> = reply::Result<Params, ReplyError>;
28
29/// A stream of replies from a chain of method calls.
30///
31/// # Owned Data Requirement
32///
33/// Stream items must use owned types (`DeserializeOwned`) rather than borrowed types. This is
34/// because the internal buffer may be reused between stream iterations, which would invalidate
35/// borrowed references. This limitation may be lifted in the future when Rust supports lending
36/// streams.
37///
38/// This is used internally by the proxy macro for streaming methods.
39pub struct ReplyStream<'c, Params, ReplyError> {
40    inner: InnerStream<'c, Params, ReplyError>,
41}
42
43impl<'c, Params, ReplyError> ReplyStream<'c, Params, ReplyError>
44where
45    Params: DeserializeOwned + Debug,
46    ReplyError: DeserializeOwned + Debug,
47{
48    /// Create a new reply stream.
49    ///
50    /// The stream will yield `reply_count` replies from the connection.
51    pub fn new<Read>(connection: &'c mut ReadConnection<Read>, reply_count: usize) -> Self
52    where
53        Read: ReadHalf + 'c,
54    {
55        // State is (connection, current_index). The connection reference flows through each
56        // iteration.
57        let inner = unfold(
58            (connection, 0),
59            move |(conn, mut current_index)| async move {
60                if current_index >= reply_count {
61                    return None;
62                }
63
64                let item = conn.receive_reply::<Params, ReplyError>().await;
65                let item_ref = item.as_ref();
66                #[cfg(feature = "std")]
67                // In std mode, we need to ignore the FDs.
68                let item_ref = item_ref.map(|r| &r.0);
69
70                // Update index based on result.
71                match item_ref {
72                    Ok(Ok(r)) if r.continues() != Some(true) => {
73                        current_index += 1;
74                    }
75                    Ok(Ok(_)) => {
76                        // Streaming reply, don't increment index yet.
77                    }
78                    Ok(Err(_)) => {
79                        // For method errors, always increment since there won't be more
80                        // replies.
81                        current_index += 1;
82                    }
83                    Err(_) => {
84                        // General error, mark stream as done.
85                        current_index = reply_count;
86                    }
87                }
88
89                Some((item, (conn, current_index)))
90            },
91        );
92
93        Self {
94            inner: Box::pin(inner),
95        }
96    }
97}
98
99impl<Params, ReplyError> Debug for ReplyStream<'_, Params, ReplyError> {
100    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
101        f.debug_struct("ReplyStream").finish_non_exhaustive()
102    }
103}
104
105impl<Params, ReplyError> Stream for ReplyStream<'_, Params, ReplyError> {
106    type Item = Result<ChainResult<Params, ReplyError>>;
107
108    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
109        self.inner.as_mut().poll_next(cx)
110    }
111}
112
113/// The boxed inner stream type for `ReplyStream`.
114type InnerStream<'c, Params, ReplyError> =
115    Pin<Box<dyn Stream<Item = Result<ChainResult<Params, ReplyError>>> + 'c>>;