Skip to main content

std_modrpc/role_impls/
multi_stream_receiver.rs

1use core::{
2    cell::RefCell,
3    marker::PhantomData,
4};
5use crate::{
6    proto::{
7        MultiStreamInitState,
8        MultiStreamItem,
9        MultiStreamItemLazy,
10        MultiStreamId,
11        MultiStreamReceiverConfig,
12    },
13    receive_stream::{ReceiveStream, StreamState},
14};
15use std::collections::HashMap;
16use std::rc::Rc;
17use modrpc::RoleSetup;
18
19pub enum ReceiveMultiStreamNextError {
20    Shutdown,
21    DecodeItem(mproto::DecodeError),
22}
23
24pub struct ReceiveMultiStream<T> {
25    stream_id: MultiStreamId,
26    receive_stream: ReceiveStream,
27    phantom: PhantomData<T>,
28}
29
30struct BrokerState {
31    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
32}
33
34pub struct MultiStreamReceiver<T> {
35    hooks: crate::MultiStreamReceiverHooks<T>,
36    broker_state: Rc<BrokerState>,
37}
38
39pub struct MultiStreamReceiverBuilder<T> {
40    name: &'static str,
41    hooks: crate::MultiStreamReceiverHooks<T>,
42    stubs: crate::MultiStreamReceiverStubs<T>,
43
44    broker_state: Rc<BrokerState>,
45}
46
47impl<T: mproto::Owned> MultiStreamReceiver<T> {
48    pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
49        let receive_stream = ReceiveStream::new(next_seq);
50        self.broker_state.streams.borrow_mut()
51            .insert(stream_id.id, receive_stream.stream_state().clone());
52
53        ReceiveMultiStream {
54            stream_id,
55            receive_stream,
56            phantom: PhantomData,
57        }
58    }
59}
60
61impl<T> Clone for MultiStreamReceiver<T> {
62    fn clone(&self) -> Self {
63        Self {
64            hooks: self.hooks.clone(),
65            broker_state: self.broker_state.clone(),
66        }
67    }
68}
69
70impl<T: mproto::Owned> ReceiveMultiStream<T> {
71    pub fn id(&self) -> MultiStreamId {
72        self.stream_id
73    }
74
75    pub async fn next(&mut self) -> Result<T, ReceiveMultiStreamNextError> {
76        use mproto::BaseLen;
77
78        let packet = self.receive_stream.next_packet().await;
79
80        let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
81            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
82        )
83        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
84
85        let owned_result = T::lazy_to_owned(
86            stream_item.payload()
87                .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
88                .ok_or(ReceiveMultiStreamNextError::Shutdown)?
89        )
90        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
91
92        Ok(owned_result)
93    }
94
95    pub async fn next_lazy(&mut self)
96        -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
97    {
98        use mproto::BaseLen;
99
100        let packet = self.receive_stream.next_packet().await;
101        packet.advance(modrpc::TransmitPacket::BASE_LEN);
102
103        let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
104
105        Ok(stream_item.map(|s| s.payload().unwrap()))
106    }
107
108    pub fn try_next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
109        use mproto::BaseLen;
110
111        let Some(packet) = self.receive_stream.try_next_packet() else {
112            return Ok(None);
113        };
114
115        let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
116            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
117        )
118        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
119
120        let owned_result = T::lazy_to_owned(
121            stream_item.payload()
122                .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
123                .ok_or(ReceiveMultiStreamNextError::Shutdown)?
124        )
125        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
126
127        Ok(Some(owned_result))
128    }
129
130    pub fn with_try_next<R>(
131        &mut self,
132        f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
133    ) -> R {
134        use mproto::BaseLen;
135
136        let Some(packet) = self.receive_stream.try_next_packet() else {
137            return f(None);
138        };
139
140        let stream_item =
141            match mproto::decode_value::<MultiStreamItemLazy<T>>(
142                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
143            ) {
144                Ok(x) => x,
145                Err(e) => {
146                    return f(Some(Err(e)));
147                },
148            };
149
150        let payload =
151            match stream_item.payload() {
152                Ok(x) => x,
153                Err(e) => {
154                    return f(Some(Err(e)));
155                },
156            };
157
158        f(Some(Ok(payload)))
159    }
160
161    pub async fn with_next<'a, Fut, R>(
162        &mut self,
163        f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
164    ) -> Option<R>
165        where Fut: std::future::Future<Output = R>
166    {
167        use mproto::BaseLen;
168
169        let packet = self.receive_stream.next_packet().await;
170
171        let stream_item =
172            match mproto::decode_value::<MultiStreamItemLazy<T>>(
173                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
174            ) {
175                Ok(x) => x,
176                Err(e) => {
177                    return Some(f(Err(e)).await);
178                },
179            };
180
181        let payload =
182            match stream_item.payload() {
183                Ok(x) => x,
184                Err(e) => {
185                    return Some(f(Err(e)).await);
186                },
187            };
188
189        Some(f(Ok(payload)).await)
190    }
191
192    pub async fn with_next_sync<'a, R>(
193        &mut self,
194        f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
195    ) -> Option<R> {
196        use mproto::BaseLen;
197
198        let packet = self.receive_stream.next_packet().await;
199
200        let stream_item =
201            match mproto::decode_value::<MultiStreamItemLazy<T>>(
202                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
203            ) {
204                Ok(x) => x,
205                Err(e) => {
206                    return Some(f(Err(e)));
207                }
208            };
209
210        let payload =
211            match stream_item.payload() {
212                Ok(Some(x)) => x,
213                Ok(None) => {
214                    // End of stream
215                    return None;
216                }
217                Err(e) => {
218                    return Some(f(Err(e)));
219                }
220            };
221
222        Some(f(Ok(payload)))
223    }
224
225    pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
226        let mut collected = Vec::new();
227        loop {
228            match self.next().await {
229                Ok(item) => { collected.push(item); Ok(()) }
230                Err(ReceiveMultiStreamNextError::Shutdown) => break,
231                Err(e) => Err(e),
232            }?;
233        }
234        Ok(collected)
235    }
236}
237
238#[derive(thiserror::Error, Debug)]
239pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
240    #[error("failed to decode MultiStream item")]
241    DecodeError(#[from] mproto::DecodeError),
242    #[error("stream sender failed: {0:?}")]
243    SenderError(E),
244    #[error("plane is shutting down")]
245    Shutdown,
246}
247
248impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
249    fn from(other: ReceiveMultiStreamNextError) -> Self {
250        match other {
251            ReceiveMultiStreamNextError::DecodeItem(e) =>
252                MultiStreamTryCollectError::DecodeError(e),
253            ReceiveMultiStreamNextError::Shutdown =>
254                MultiStreamTryCollectError::Shutdown,
255        }
256    }
257}
258
259impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
260    pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
261        let mut collected = Vec::new();
262        loop {
263            match self.next().await {
264                Ok(Ok(item)) => { collected.push(item); Ok(()) }
265                Ok(Err(e)) => return Err(MultiStreamTryCollectError::SenderError(e.into())),
266                Err(ReceiveMultiStreamNextError::Shutdown) => break,
267                Err(e) => Err(e),
268            }?;
269        }
270        Ok(collected)
271    }
272}
273
274impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
275    pub fn new(
276        name: &'static str,
277        hooks: crate::MultiStreamReceiverHooks<T>,
278        stubs: crate::MultiStreamReceiverStubs<T>,
279        _config: &MultiStreamReceiverConfig,
280        _init: MultiStreamInitState,
281    ) -> Self {
282        Self {
283            name, hooks, stubs,
284            broker_state: Rc::new(BrokerState {
285                streams: RefCell::new(HashMap::new()),
286            }),
287        }
288    }
289
290    pub fn create_handle(
291        &self,
292        _setup: &RoleSetup,
293    ) -> crate::MultiStreamReceiver<T> {
294        crate::MultiStreamReceiver {
295            hooks: self.hooks.clone(),
296            broker_state: self.broker_state.clone(),
297        }
298    }
299
300    pub fn build(
301        self,
302        setup: &RoleSetup,
303    ) {
304        use mproto::BaseLen;
305
306        let broker_state = self.broker_state;
307        self.stubs.item.inline_untyped(setup, move |_source, packet| {
308            let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
309            let (seq, stream_id, shutdown) = {
310                let Ok(stream_item) =
311                    mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
312                else {
313                    return;
314                };
315                let Ok(seq) = stream_item.seq() else {
316                    return;
317                };
318                let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
319                    return;
320                };
321                let Ok(payload) = stream_item.payload() else {
322                    return;
323                };
324                (seq, stream_id, payload.is_none())
325            };
326
327            let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
328                log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
329                return;
330            };
331
332            let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
333            if stream_is_done {
334                log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
335                broker_state.streams.borrow_mut().remove(&stream_id);
336            }
337        })
338        .subscribe();
339    }
340}
341