std_modrpc/role_impls/
multi_stream_receiver.rs

1use core::cell::{Cell, RefCell};
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4use crate::proto::{
5    MultiStreamInitState,
6    MultiStreamItem,
7    MultiStreamItemLazy,
8    MultiStreamId,
9    MultiStreamReceiverConfig,
10};
11use std::collections::{BinaryHeap, HashMap};
12use std::rc::Rc;
13use modrpc::RoleSetup;
14
15pub enum ReceiveMultiStreamNextError {
16    Shutdown,
17    DecodeItem(mproto::DecodeError),
18}
19
/// Consumer-side handle for a single stream carrying items of type `T`.
///
/// Instances are produced by `MultiStreamReceiver::new_stream`.
///
/// Fields:
/// - `stream_id`: identifier the stream was registered under; returned by `id()`.
/// - `local_queue_rx`: receiving end of the capacity-1 local queue into which the
///   stream state forwards the next in-order packet.
/// - `stream_state`: reordering state shared (via `Rc`) with the broker's stream
///   map; polled directly through `StreamState::try_pop` when the queue is empty.
/// - `phantom`: records the item type `T` without storing a value of it.
20pub struct ReceiveMultiStream<T> {
21    stream_id: MultiStreamId,
22    local_queue_rx: localq::mpsc::Receiver<modrpc::Packet>,
23    stream_state: Rc<StreamState>,
24    phantom: PhantomData<T>,
25}
26
/// Registry of active streams, shared between receiver handles and the item
/// handler installed by `MultiStreamReceiverBuilder::build`.
///
/// `streams` maps a stream's numeric id (`MultiStreamId::id`) to its per-stream
/// state. Entries are inserted by `new_stream` and removed by the item handler
/// once `StreamState::handle_item` reports the stream as finished.
27struct BrokerState {
28    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
29}
30
/// Cloneable handle used to register incoming streams via `new_stream`.
///
/// Fields:
/// - `hooks`: role hooks cloned from the builder; not otherwise used in this file.
/// - `broker_state`: stream registry shared with the builder's item handler and
///   with every clone of this handle.
31pub struct MultiStreamReceiver<T> {
32    hooks: crate::MultiStreamReceiverHooks<T>,
33    broker_state: Rc<BrokerState>,
34}
35
/// Builder for the receiver role: hands out `MultiStreamReceiver` handles and,
/// on `build`, installs the handler that routes incoming items to streams.
///
/// Fields:
/// - `name`: label included in the "Unknown stream_id" warning log.
/// - `hooks`: cloned into every handle created by `create_handle`.
/// - `stubs`: `stubs.item` is the stub the item handler is attached to in `build`.
/// - `broker_state`: stream registry shared with all handles and the handler.
36pub struct MultiStreamReceiverBuilder<T> {
37    name: &'static str,
38    hooks: crate::MultiStreamReceiverHooks<T>,
39    stubs: crate::MultiStreamReceiverStubs<T>,
40
41    broker_state: Rc<BrokerState>,
42}
43
44impl<T: mproto::Owned> MultiStreamReceiver<T> {
45    pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
46        // TODO use a waker cell instead of a channel
47        let (local_queue_tx, local_queue_rx) = localq::mpsc::channel(1);
48        let stream_state = Rc::new(StreamState::new(local_queue_tx, next_seq));
49        self.broker_state.streams.borrow_mut().insert(stream_id.id, stream_state.clone());
50
51        ReceiveMultiStream {
52            stream_id,
53            local_queue_rx,
54            stream_state,
55            phantom: PhantomData,
56        }
57    }
58}
59
60impl<T> Clone for MultiStreamReceiver<T> {
61    fn clone(&self) -> Self {
62        Self {
63            hooks: self.hooks.clone(),
64            broker_state: self.broker_state.clone(),
65        }
66    }
67}
68
69impl<T: mproto::Owned> ReceiveMultiStream<T> {
70    pub fn id(&self) -> MultiStreamId {
71        self.stream_id
72    }
73
74    fn try_next_packet(&mut self) -> Option<modrpc::Packet> {
75        if let Ok(packet) = self.local_queue_rx.try_recv() {
76            return Some(packet);
77        }
78
79        self.stream_state.try_pop()
80    }
81
82    async fn next_packet(&mut self) -> modrpc::Packet {
83        if let Ok(packet) = self.local_queue_rx.try_recv() {
84            return packet;
85        }
86
87        if let Some(packet) = self.stream_state.try_pop() {
88            return packet;
89        }
90
91        self.local_queue_rx.recv().await.unwrap()
92    }
93
94    pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
95        use mproto::BaseLen;
96
97        let packet = self.next_packet().await;
98
99        let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
100            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
101        )
102        .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
103
104        let owned_result = stream_item.payload()
105            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
106            .map(|i| T::lazy_to_owned(i))
107            .transpose()
108            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
109
110        Ok(owned_result)
111    }
112
113    pub async fn next_lazy(&mut self)
114        -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
115    {
116        use mproto::BaseLen;
117
118        let packet = self.next_packet().await;
119        packet.advance(modrpc::TransmitPacket::BASE_LEN);
120
121        let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
122
123        Ok(stream_item.map(|s| s.payload().unwrap()))
124    }
125
126    pub fn with_try_next<R>(
127        &mut self,
128        f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
129    ) -> R {
130        use mproto::BaseLen;
131
132        let Some(packet) = self.try_next_packet() else {
133            return f(None);
134        };
135
136        let stream_item =
137            match mproto::decode_value::<MultiStreamItemLazy<T>>(
138                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
139            ) {
140                Ok(x) => x,
141                Err(e) => {
142                    return f(Some(Err(e)));
143                },
144            };
145
146        let payload =
147            match stream_item.payload() {
148                Ok(x) => x,
149                Err(e) => {
150                    return f(Some(Err(e)));
151                },
152            };
153
154        f(Some(Ok(payload)))
155    }
156
157    pub async fn with_next<'a, Fut, R>(
158        &mut self,
159        f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
160    ) -> Option<R>
161        where Fut: std::future::Future<Output = R>
162    {
163        use mproto::BaseLen;
164
165        let packet = self.next_packet().await;
166
167        let stream_item =
168            match mproto::decode_value::<MultiStreamItemLazy<T>>(
169                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
170            ) {
171                Ok(x) => x,
172                Err(e) => {
173                    return Some(f(Err(e)).await);
174                },
175            };
176
177        let payload =
178            match stream_item.payload() {
179                Ok(x) => x,
180                Err(e) => {
181                    return Some(f(Err(e)).await);
182                },
183            };
184
185        Some(f(Ok(payload)).await)
186    }
187
188    pub async fn with_next_sync<'a, R>(
189        &mut self,
190        f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
191    ) -> Option<R> {
192        use mproto::BaseLen;
193
194        let packet = self.next_packet().await;
195
196        let stream_item =
197            match mproto::decode_value::<MultiStreamItemLazy<T>>(
198                &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
199            ) {
200                Ok(x) => x,
201                Err(e) => {
202                    return Some(f(Err(e)));
203                }
204            };
205
206        let payload =
207            match stream_item.payload() {
208                Ok(Some(x)) => x,
209                Ok(None) => {
210                    // End of stream
211                    return None;
212                }
213                Err(e) => {
214                    return Some(f(Err(e)));
215                }
216            };
217
218        Some(f(Ok(payload)))
219    }
220
221    pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
222        let mut collected = Vec::new();
223        while let Some(item) = self.next().await? {
224            collected.push(item);
225        }
226        Ok(collected)
227    }
228}
229
230#[derive(thiserror::Error, Debug)]
231pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
232    #[error("failed to decode MultiStream item")]
233    DecodeError(#[from] mproto::DecodeError),
234    #[error("stream sender failed: {0:?}")]
235    SenderError(E),
236    #[error("plane is shutting down")]
237    Shutdown,
238}
239
240impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
241    fn from(other: ReceiveMultiStreamNextError) -> Self {
242        match other {
243            ReceiveMultiStreamNextError::DecodeItem(e) =>
244                MultiStreamTryCollectError::DecodeError(e),
245            ReceiveMultiStreamNextError::Shutdown =>
246                MultiStreamTryCollectError::Shutdown,
247        }
248    }
249}
250
251impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
252    pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
253        let mut collected = Vec::new();
254        while let Some(item) =
255            self.next().await?
256                .transpose()
257                .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
258        {
259            collected.push(item);
260        }
261        Ok(collected)
262    }
263}
264
265impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
266    pub fn new(
267        name: &'static str,
268        hooks: crate::MultiStreamReceiverHooks<T>,
269        stubs: crate::MultiStreamReceiverStubs<T>,
270        _config: &MultiStreamReceiverConfig,
271        _init: MultiStreamInitState,
272    ) -> Self {
273        Self {
274            name, hooks, stubs,
275            broker_state: Rc::new(BrokerState {
276                streams: RefCell::new(HashMap::new()),
277            }),
278        }
279    }
280
281    pub fn create_handle(
282        &self,
283        _setup: &RoleSetup,
284    ) -> crate::MultiStreamReceiver<T> {
285        crate::MultiStreamReceiver {
286            hooks: self.hooks.clone(),
287            broker_state: self.broker_state.clone(),
288        }
289    }
290
291    pub fn build(
292        self,
293        setup: &RoleSetup,
294    ) {
295        use mproto::BaseLen;
296
297        let broker_state = self.broker_state;
298        self.stubs.item.inline_untyped(setup, move |_source, packet| {
299            let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
300            let (seq, stream_id, shutdown) = {
301                let Ok(stream_item) =
302                    mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
303                else {
304                    return;
305                };
306                let Ok(seq) = stream_item.seq() else {
307                    return;
308                };
309                let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
310                    return;
311                };
312                let Ok(payload) = stream_item.payload() else {
313                    return;
314                };
315                (seq, stream_id, payload.is_none())
316            };
317
318            let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
319                log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
320                return;
321            };
322
323            let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
324            if stream_is_done {
325                log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
326                broker_state.streams.borrow_mut().remove(&stream_id);
327            }
328        })
329        .subscribe();
330    }
331}
332
333// Wrapper for MultiStreamItem that is Eq + PartialEq + Ord + PartialOrd
334
335struct OrderedItem {
336    seq: u64,
337    shutdown: bool,
338    packet: modrpc::Packet,
339}
340
341impl PartialEq for OrderedItem {
342    fn eq(&self, other: &Self) -> bool { self.seq.eq(&other.seq) }
343}
344
345impl Eq for OrderedItem { }
346
347impl PartialOrd for OrderedItem {
348    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
349        self.seq.partial_cmp(&other.seq)
350    }
351}
352
353impl Ord for OrderedItem {
354    fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.seq.cmp(&other.seq) }
355}
356
357// Shared state of a single stream receiver
358
359struct StreamState {
360    heap: RefCell<BinaryHeap<Reverse<OrderedItem>>>,
361    first_seq: Cell<u64>,
362    last_seq: Cell<Option<u64>>,
363    received_count: Cell<u64>,
364    next_seq: Cell<Option<u64>>,
365    local_queue_tx: localq::mpsc::Sender<modrpc::Packet>,
366}
367
368impl StreamState {
369    fn new(local_queue_tx: localq::mpsc::Sender<modrpc::Packet>, next_seq: Option<u64>) -> Self {
370        Self {
371            heap: RefCell::new(BinaryHeap::new()),
372            first_seq: Cell::new(0),
373            last_seq: Cell::new(None),
374            received_count: Cell::new(0),
375            next_seq: Cell::new(next_seq),
376            local_queue_tx,
377        }
378    }
379
380    fn try_pop(&self) -> Option<modrpc::Packet> {
381        let mut heap = self.heap.borrow_mut();
382        let Reverse(stream_item) = heap.peek()?;
383
384        let next_seq = self.next_seq.get().unwrap_or_else(|| {
385            self.first_seq.set(stream_item.seq);
386            stream_item.seq
387        });
388
389        if stream_item.seq != next_seq {
390            return None;
391        }
392        self.next_seq.set(Some(next_seq + 1));
393
394        Some(heap.pop().unwrap().0.packet)
395    }
396
397    /// Returns true if the stream is finished and should be cleaned up.
398    fn handle_item(&self, seq: u64, shutdown: bool, packet: modrpc::Packet) -> bool {
399        let mut heap = self.heap.borrow_mut();
400
401        // If we don't know the next seq, treat the first item we get as the start of the stream.
402        let next_seq = self.next_seq.get().unwrap_or_else(|| {
403            self.first_seq.set(seq);
404            seq
405        });
406        // If we subsequently receive earlier items, we drop them.
407        if seq < next_seq {
408            return false;
409        }
410
411        // Reverse order so that heap produces item with smallest seq.
412        heap.push(Reverse(OrderedItem { seq, shutdown, packet }));
413        self.received_count.set(self.received_count.get() + 1);
414        if shutdown {
415            self.last_seq.set(Some(seq));
416        }
417
418        while let Some(Reverse(stream_item)) = heap.peek() {
419            if stream_item.seq != next_seq { break; }
420
421            // Unwrap guaranteed to succeed.
422            let Reverse(stream_item) = heap.pop().unwrap();
423
424            if let Err(localq::mpsc::TrySendError::Full(packet)) =
425                self.local_queue_tx.try_send(stream_item.packet)
426            {
427                heap.push(Reverse(OrderedItem {
428                    seq: stream_item.seq,
429                    shutdown: stream_item.shutdown,
430                    packet,
431                }));
432                break;
433            }
434
435            self.next_seq.set(Some(next_seq + 1));
436        }
437
438        if let Some(last_seq) = self.last_seq.get() {
439            (last_seq - self.first_seq.get() + 1) == self.received_count.get()
440        } else {
441            false
442        }
443    }
444}
445