Skip to main content

std_modrpc/role_impls/
multi_stream_receiver.rs

1use core::{
2    cell::RefCell,
3    marker::PhantomData,
4};
5use crate::{
6    proto::{
7        MultiStreamInitState,
8        MultiStreamItem,
9        MultiStreamItemLazy,
10        MultiStreamId,
11        MultiStreamReceiverConfig,
12    },
13    receive_stream::{ReceiveStream, StreamState},
14};
15use std::collections::HashMap;
16use std::rc::Rc;
17use modrpc::RoleSetup;
18
19pub enum ReceiveMultiStreamNextError {
20    Shutdown,
21    DecodeItem(mproto::DecodeError),
22}
23
/// Receiving end of one stream of a multi-stream, yielding items of type `T`.
///
/// Created by `MultiStreamReceiver::new_stream`, which registers the shared state of
/// `receive_stream` with the broker so that incoming items for `stream_id` are routed here.
pub struct ReceiveMultiStream<T> {
    // Id this stream was registered under; returned by `id()`.
    stream_id: MultiStreamId,
    // Source of the packets the broker routes to this stream; read by the
    // `next*` / `try_next` / `with_*` methods.
    receive_stream: ReceiveStream,
    // Ties the handle to the item type `T` without storing a value of it.
    phantom: PhantomData<T>,
}
29
/// Registry of active streams, shared (via `Rc`) between receiver handles and the item
/// handler installed by `MultiStreamReceiverBuilder::build`.
struct BrokerState {
    // Keyed by `MultiStreamId::id`. Entries are inserted by `new_stream` and removed by the
    // item handler once `StreamState::handle_item` reports the stream done.
    // `RefCell` rather than a lock: the state is shared through `Rc`, i.e. single-threaded.
    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
}
33
/// Handle used to register new receive streams; clones share one stream registry.
pub struct MultiStreamReceiver<T> {
    hooks: crate::MultiStreamReceiverHooks<T>,
    // Shared with the builder and with the item handler that routes incoming packets.
    broker_state: Rc<BrokerState>,
}
38
/// Builder that hands out `MultiStreamReceiver` handles and installs the item handler that
/// routes incoming items to registered streams.
pub struct MultiStreamReceiverBuilder<T> {
    // Included in log messages emitted by the item handler.
    name: &'static str,
    hooks: crate::MultiStreamReceiverHooks<T>,
    stubs: crate::MultiStreamReceiverStubs<T>,

    // Stream registry shared by every handle from `create_handle` and by the handler
    // installed in `build`.
    broker_state: Rc<BrokerState>,
}
46
47impl<T: mproto::Owned> MultiStreamReceiver<T> {
48    pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
49        let receive_stream = ReceiveStream::new(next_seq);
50        self.broker_state.streams.borrow_mut()
51            .insert(stream_id.id, receive_stream.stream_state().clone());
52
53        ReceiveMultiStream {
54            stream_id,
55            receive_stream,
56            phantom: PhantomData,
57        }
58    }
59}
60
61impl<T> Clone for MultiStreamReceiver<T> {
62    fn clone(&self) -> Self {
63        Self {
64            hooks: self.hooks.clone(),
65            broker_state: self.broker_state.clone(),
66        }
67    }
68}
69
70impl<T: mproto::Owned> ReceiveMultiStream<T> {
71    pub fn id(&self) -> MultiStreamId {
72        self.stream_id
73    }
74
75    pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
76        use mproto::BaseLen;
77
78        let packet = self.receive_stream.next_packet().await;
79
80        let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
81            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
82        )
83        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
84
85        let owned_result = stream_item.payload()
86            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
87            .map(|i| T::lazy_to_owned(i))
88            .transpose()
89            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
90
91        Ok(owned_result)
92    }
93
94    pub async fn next_lazy(&mut self)
95        -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
96    {
97        use mproto::BaseLen;
98
99        let packet = self.receive_stream.next_packet().await;
100        packet.advance(modrpc::TransmitPacket::BASE_LEN);
101
102        let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
103
104        Ok(stream_item.map(|s| s.payload().unwrap()))
105    }
106
107    pub fn try_next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
108        use mproto::BaseLen;
109
110        let Some(packet) = self.receive_stream.try_next_packet() else {
111            return Ok(None);
112        };
113
114        let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
115            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
116        )
117        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
118
119        let owned_result = stream_item.payload()
120            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
121            .map(|i| T::lazy_to_owned(i))
122            .transpose()
123            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
124
125        Ok(owned_result)
126    }
127
128    pub fn with_try_next<R>(
129        &mut self,
130        f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
131    ) -> R {
132        use mproto::BaseLen;
133
134        let Some(packet) = self.receive_stream.try_next_packet() else {
135            return f(None);
136        };
137
138        let stream_item =
139            match mproto::decode_value::<MultiStreamItemLazy<T>>(
140                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
141            ) {
142                Ok(x) => x,
143                Err(e) => {
144                    return f(Some(Err(e)));
145                },
146            };
147
148        let payload =
149            match stream_item.payload() {
150                Ok(x) => x,
151                Err(e) => {
152                    return f(Some(Err(e)));
153                },
154            };
155
156        f(Some(Ok(payload)))
157    }
158
159    pub async fn with_next<'a, Fut, R>(
160        &mut self,
161        f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
162    ) -> Option<R>
163        where Fut: std::future::Future<Output = R>
164    {
165        use mproto::BaseLen;
166
167        let packet = self.receive_stream.next_packet().await;
168
169        let stream_item =
170            match mproto::decode_value::<MultiStreamItemLazy<T>>(
171                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
172            ) {
173                Ok(x) => x,
174                Err(e) => {
175                    return Some(f(Err(e)).await);
176                },
177            };
178
179        let payload =
180            match stream_item.payload() {
181                Ok(x) => x,
182                Err(e) => {
183                    return Some(f(Err(e)).await);
184                },
185            };
186
187        Some(f(Ok(payload)).await)
188    }
189
190    pub async fn with_next_sync<'a, R>(
191        &mut self,
192        f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
193    ) -> Option<R> {
194        use mproto::BaseLen;
195
196        let packet = self.receive_stream.next_packet().await;
197
198        let stream_item =
199            match mproto::decode_value::<MultiStreamItemLazy<T>>(
200                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
201            ) {
202                Ok(x) => x,
203                Err(e) => {
204                    return Some(f(Err(e)));
205                }
206            };
207
208        let payload =
209            match stream_item.payload() {
210                Ok(Some(x)) => x,
211                Ok(None) => {
212                    // End of stream
213                    return None;
214                }
215                Err(e) => {
216                    return Some(f(Err(e)));
217                }
218            };
219
220        Some(f(Ok(payload)))
221    }
222
223    pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
224        let mut collected = Vec::new();
225        while let Some(item) = self.next().await? {
226            collected.push(item);
227        }
228        Ok(collected)
229    }
230}
231
/// Error returned by `ReceiveMultiStream::try_collect` for streams whose items are
/// `Result<T, E>`.
#[derive(thiserror::Error, Debug)]
pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
    /// A stream item could not be decoded.
    #[error("failed to decode MultiStream item")]
    DecodeError(#[from] mproto::DecodeError),
    /// The sender streamed an `Err(E)` item; collection stops there.
    #[error("stream sender failed: {0:?}")]
    SenderError(E),
    /// Converted from `ReceiveMultiStreamNextError::Shutdown`.
    #[error("plane is shutting down")]
    Shutdown,
}
241
242impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
243    fn from(other: ReceiveMultiStreamNextError) -> Self {
244        match other {
245            ReceiveMultiStreamNextError::DecodeItem(e) =>
246                MultiStreamTryCollectError::DecodeError(e),
247            ReceiveMultiStreamNextError::Shutdown =>
248                MultiStreamTryCollectError::Shutdown,
249        }
250    }
251}
252
253impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
254    pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
255        let mut collected = Vec::new();
256        while let Some(item) =
257            self.next().await?
258                .transpose()
259                .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
260        {
261            collected.push(item);
262        }
263        Ok(collected)
264    }
265}
266
267impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
268    pub fn new(
269        name: &'static str,
270        hooks: crate::MultiStreamReceiverHooks<T>,
271        stubs: crate::MultiStreamReceiverStubs<T>,
272        _config: &MultiStreamReceiverConfig,
273        _init: MultiStreamInitState,
274    ) -> Self {
275        Self {
276            name, hooks, stubs,
277            broker_state: Rc::new(BrokerState {
278                streams: RefCell::new(HashMap::new()),
279            }),
280        }
281    }
282
283    pub fn create_handle(
284        &self,
285        _setup: &RoleSetup,
286    ) -> crate::MultiStreamReceiver<T> {
287        crate::MultiStreamReceiver {
288            hooks: self.hooks.clone(),
289            broker_state: self.broker_state.clone(),
290        }
291    }
292
293    pub fn build(
294        self,
295        setup: &RoleSetup,
296    ) {
297        use mproto::BaseLen;
298
299        let broker_state = self.broker_state;
300        self.stubs.item.inline_untyped(setup, move |_source, packet| {
301            let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
302            let (seq, stream_id, shutdown) = {
303                let Ok(stream_item) =
304                    mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
305                else {
306                    return;
307                };
308                let Ok(seq) = stream_item.seq() else {
309                    return;
310                };
311                let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
312                    return;
313                };
314                let Ok(payload) = stream_item.payload() else {
315                    return;
316                };
317                (seq, stream_id, payload.is_none())
318            };
319
320            let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
321                log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
322                return;
323            };
324
325            let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
326            if stream_is_done {
327                log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
328                broker_state.streams.borrow_mut().remove(&stream_id);
329            }
330        })
331        .subscribe();
332    }
333}
334