std_modrpc/role_impls/
multi_stream_receiver.rs

1use crate::proto::{
2    MultiStreamId, MultiStreamInitState, MultiStreamItem, MultiStreamItemLazy,
3    MultiStreamReceiverConfig,
4};
5use core::cell::{Cell, RefCell};
6use core::cmp::Reverse;
7use core::marker::PhantomData;
8use modrpc::RoleSetup;
9use std::collections::{BinaryHeap, HashMap};
10use std::rc::Rc;
11
12pub enum ReceiveMultiStreamNextError {
13    Shutdown,
14    DecodeItem(mproto::DecodeError),
15}
16
/// Receiving end of a single stream, created by `MultiStreamReceiver::new_stream`.
///
/// `T` is the item type the packets are decoded into; no `T` is stored, hence the
/// `PhantomData` field.
pub struct ReceiveMultiStream<T> {
    // Identifier the stream was registered under.
    stream_id: MultiStreamId,
    // Receives the packets that `StreamState::handle_item` forwarded in sequence order.
    local_queue_rx: localq::mpsc::Receiver<modrpc::Packet>,
    // State shared with the broker; it also buffers packets that could not be forwarded yet.
    stream_state: Rc<StreamState>,
    phantom: PhantomData<T>,
}
23
/// Registry of the streams currently known to a receiver, shared between the receiver
/// handles and the packet handler installed by `MultiStreamReceiverBuilder::build`.
struct BrokerState {
    // Keyed by `MultiStreamId::id`.
    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
}
27
/// Cloneable handle used to register new streams via `new_stream`.
pub struct MultiStreamReceiver<T> {
    // Not read by the code in this file; it is only cloned along with the handle.
    hooks: crate::MultiStreamReceiverHooks<T>,
    // Stream registry shared with the builder's packet handler and with every clone.
    broker_state: Rc<BrokerState>,
}
32
/// Builder that hands out `MultiStreamReceiver` handles and installs the packet handler.
pub struct MultiStreamReceiverBuilder<T> {
    // Only used in the "Unknown stream_id" warning emitted by the packet handler.
    name: &'static str,
    // Cloned into every handle returned by `create_handle`.
    hooks: crate::MultiStreamReceiverHooks<T>,
    // `stubs.item` is the stub the packet handler is attached to in `build`.
    stubs: crate::MultiStreamReceiverStubs<T>,

    // Stream registry shared with all handles created from this builder.
    broker_state: Rc<BrokerState>,
}
40
41impl<T: mproto::Owned> MultiStreamReceiver<T> {
42    pub fn new_stream(
43        &self,
44        stream_id: MultiStreamId,
45        next_seq: Option<u64>,
46    ) -> ReceiveMultiStream<T> {
47        // TODO use a waker cell instead of a channel
48        let (local_queue_tx, local_queue_rx) = localq::mpsc::channel(1);
49        let stream_state = Rc::new(StreamState::new(local_queue_tx, next_seq));
50        self.broker_state
51            .streams
52            .borrow_mut()
53            .insert(stream_id.id, stream_state.clone());
54
55        ReceiveMultiStream {
56            stream_id,
57            local_queue_rx,
58            stream_state,
59            phantom: PhantomData,
60        }
61    }
62}
63
64impl<T> Clone for MultiStreamReceiver<T> {
65    fn clone(&self) -> Self {
66        Self {
67            hooks: self.hooks.clone(),
68            broker_state: self.broker_state.clone(),
69        }
70    }
71}
72
73impl<T: mproto::Owned> ReceiveMultiStream<T> {
74    pub fn id(&self) -> MultiStreamId {
75        self.stream_id
76    }
77
78    fn try_next_packet(&mut self) -> Option<modrpc::Packet> {
79        if let Ok(packet) = self.local_queue_rx.try_recv() {
80            return Some(packet);
81        }
82
83        self.stream_state.try_pop()
84    }
85
86    async fn next_packet(&mut self) -> modrpc::Packet {
87        if let Ok(packet) = self.local_queue_rx.try_recv() {
88            return packet;
89        }
90
91        if let Some(packet) = self.stream_state.try_pop() {
92            return packet;
93        }
94
95        self.local_queue_rx.recv().await.unwrap()
96    }
97
98    pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
99        use mproto::BaseLen;
100
101        let packet = self.next_packet().await;
102
103        let stream_item: MultiStreamItemLazy<T> =
104            mproto::decode_value(&packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..])
105                .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
106
107        let owned_result = stream_item
108            .payload()
109            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
110            .map(|i| T::lazy_to_owned(i))
111            .transpose()
112            .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
113
114        Ok(owned_result)
115    }
116
117    pub async fn next_lazy(
118        &mut self,
119    ) -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError> {
120        use mproto::BaseLen;
121
122        let packet = self.next_packet().await;
123        packet.advance(modrpc::TransmitPacket::BASE_LEN);
124
125        let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
126
127        Ok(stream_item.map(|s| s.payload().unwrap()))
128    }
129
130    pub fn with_try_next<R>(
131        &mut self,
132        f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
133    ) -> R {
134        use mproto::BaseLen;
135
136        let Some(packet) = self.try_next_packet() else {
137            return f(None);
138        };
139
140        let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
141            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
142        ) {
143            Ok(x) => x,
144            Err(e) => {
145                return f(Some(Err(e)));
146            }
147        };
148
149        let payload = match stream_item.payload() {
150            Ok(x) => x,
151            Err(e) => {
152                return f(Some(Err(e)));
153            }
154        };
155
156        f(Some(Ok(payload)))
157    }
158
159    pub async fn with_next<'a, Fut, R>(
160        &mut self,
161        f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
162    ) -> Option<R>
163    where
164        Fut: std::future::Future<Output = R>,
165    {
166        use mproto::BaseLen;
167
168        let packet = self.next_packet().await;
169
170        let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
171            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
172        ) {
173            Ok(x) => x,
174            Err(e) => {
175                return Some(f(Err(e)).await);
176            }
177        };
178
179        let payload = match stream_item.payload() {
180            Ok(x) => x,
181            Err(e) => {
182                return Some(f(Err(e)).await);
183            }
184        };
185
186        Some(f(Ok(payload)).await)
187    }
188
189    pub async fn with_next_sync<'a, R>(
190        &mut self,
191        f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
192    ) -> Option<R> {
193        use mproto::BaseLen;
194
195        let packet = self.next_packet().await;
196
197        let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
198            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
199        ) {
200            Ok(x) => x,
201            Err(e) => {
202                return Some(f(Err(e)));
203            }
204        };
205
206        let payload = match stream_item.payload() {
207            Ok(Some(x)) => x,
208            Ok(None) => {
209                // End of stream
210                return None;
211            }
212            Err(e) => {
213                return Some(f(Err(e)));
214            }
215        };
216
217        Some(f(Ok(payload)))
218    }
219
220    pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
221        let mut collected = Vec::new();
222        while let Some(item) = self.next().await? {
223            collected.push(item);
224        }
225        Ok(collected)
226    }
227}
228
/// Error returned by `ReceiveMultiStream::try_collect`.
///
/// `E` is the error type carried by the stream's `Result<T, E>` items.
#[derive(thiserror::Error, Debug)]
pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
    /// An item could not be decoded.
    #[error("failed to decode MultiStream item")]
    DecodeError(#[from] mproto::DecodeError),
    /// The stream yielded an `Err(E)` item.
    #[error("stream sender failed: {0:?}")]
    SenderError(E),
    /// Converted from `ReceiveMultiStreamNextError::Shutdown`.
    #[error("plane is shutting down")]
    Shutdown,
}
238
239impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
240    fn from(other: ReceiveMultiStreamNextError) -> Self {
241        match other {
242            ReceiveMultiStreamNextError::DecodeItem(e) => {
243                MultiStreamTryCollectError::DecodeError(e)
244            }
245            ReceiveMultiStreamNextError::Shutdown => MultiStreamTryCollectError::Shutdown,
246        }
247    }
248}
249
250impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
251    pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
252        let mut collected = Vec::new();
253        while let Some(item) = self
254            .next()
255            .await?
256            .transpose()
257            .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
258        {
259            collected.push(item);
260        }
261        Ok(collected)
262    }
263}
264
265impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
266    pub fn new(
267        name: &'static str,
268        hooks: crate::MultiStreamReceiverHooks<T>,
269        stubs: crate::MultiStreamReceiverStubs<T>,
270        _config: &MultiStreamReceiverConfig,
271        _init: MultiStreamInitState,
272    ) -> Self {
273        Self {
274            name,
275            hooks,
276            stubs,
277            broker_state: Rc::new(BrokerState {
278                streams: RefCell::new(HashMap::new()),
279            }),
280        }
281    }
282
283    pub fn create_handle(&self, _setup: &RoleSetup) -> crate::MultiStreamReceiver<T> {
284        crate::MultiStreamReceiver {
285            hooks: self.hooks.clone(),
286            broker_state: self.broker_state.clone(),
287        }
288    }
289
290    pub fn build(self, setup: &RoleSetup) {
291        use mproto::BaseLen;
292
293        let broker_state = self.broker_state;
294        self.stubs
295            .item
296            .inline_untyped(setup, move |_source, packet| {
297                let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
298                let (seq, stream_id, shutdown) = {
299                    let Ok(stream_item) =
300                        mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
301                    else {
302                        return;
303                    };
304                    let Ok(seq) = stream_item.seq() else {
305                        return;
306                    };
307                    let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
308                        return;
309                    };
310                    let Ok(payload) = stream_item.payload() else {
311                        return;
312                    };
313                    (seq, stream_id, payload.is_none())
314                };
315
316                let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned()
317                else {
318                    log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
319                    return;
320                };
321
322                let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
323                if stream_is_done {
324                    log::debug!(
325                        "MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}"
326                    );
327                    broker_state.streams.borrow_mut().remove(&stream_id);
328                }
329            })
330            .subscribe();
331    }
332}
333
// Wrapper for MultiStreamItem that is Eq + PartialEq + Ord + PartialOrd

/// A received packet together with the fields needed to order it.
///
/// Equality and ordering look at `seq` only.
struct OrderedItem {
    // Sequence number of the item within its stream.
    seq: u64,
    // True when the item carried no payload, i.e. it marks the end of the stream.
    shutdown: bool,
    // The packet the item arrived in.
    packet: modrpc::Packet,
}
341
342impl PartialEq for OrderedItem {
343    fn eq(&self, other: &Self) -> bool {
344        self.seq.eq(&other.seq)
345    }
346}
347
348impl Eq for OrderedItem {}
349
350impl PartialOrd for OrderedItem {
351    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
352        self.seq.partial_cmp(&other.seq)
353    }
354}
355
356impl Ord for OrderedItem {
357    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
358        self.seq.cmp(&other.seq)
359    }
360}
361
// Shared state of a single stream receiver

/// Reorder buffer and bookkeeping for one stream.
///
/// Shared (via `Rc`) between the broker's stream registry and the `ReceiveMultiStream`
/// that reads the stream.
struct StreamState {
    // Min-heap (through `Reverse`) of packets that arrived but have not been handed out yet.
    heap: RefCell<BinaryHeap<Reverse<OrderedItem>>>,
    // Sequence number the completion check in `handle_item` counts from.
    first_seq: Cell<u64>,
    // Sequence number of the end-of-stream item, once it has arrived.
    last_seq: Cell<Option<u64>>,
    // Number of items accepted by `handle_item`; items older than `next_seq` are not counted.
    received_count: Cell<u64>,
    // Next sequence number to hand to the receiver; `None` until it is known.
    next_seq: Cell<Option<u64>>,
    // Sending half of the queue read by `ReceiveMultiStream`.
    local_queue_tx: localq::mpsc::Sender<modrpc::Packet>,
}
372
373impl StreamState {
374    fn new(local_queue_tx: localq::mpsc::Sender<modrpc::Packet>, next_seq: Option<u64>) -> Self {
375        Self {
376            heap: RefCell::new(BinaryHeap::new()),
377            first_seq: Cell::new(0),
378            last_seq: Cell::new(None),
379            received_count: Cell::new(0),
380            next_seq: Cell::new(next_seq),
381            local_queue_tx,
382        }
383    }
384
385    fn try_pop(&self) -> Option<modrpc::Packet> {
386        let mut heap = self.heap.borrow_mut();
387        let Reverse(stream_item) = heap.peek()?;
388
389        let next_seq = self.next_seq.get().unwrap_or_else(|| {
390            self.first_seq.set(stream_item.seq);
391            stream_item.seq
392        });
393
394        if stream_item.seq != next_seq {
395            return None;
396        }
397        self.next_seq.set(Some(next_seq + 1));
398
399        Some(heap.pop().unwrap().0.packet)
400    }
401
402    /// Returns true if the stream is finished and should be cleaned up.
403    fn handle_item(&self, seq: u64, shutdown: bool, packet: modrpc::Packet) -> bool {
404        let mut heap = self.heap.borrow_mut();
405
406        // If we don't know the next seq, treat the first item we get as the start of the stream.
407        let next_seq = self.next_seq.get().unwrap_or_else(|| {
408            self.first_seq.set(seq);
409            seq
410        });
411        // If we subsequently receive earlier items, we drop them.
412        if seq < next_seq {
413            return false;
414        }
415
416        // Reverse order so that heap produces item with smallest seq.
417        heap.push(Reverse(OrderedItem {
418            seq,
419            shutdown,
420            packet,
421        }));
422        self.received_count.set(self.received_count.get() + 1);
423        if shutdown {
424            self.last_seq.set(Some(seq));
425        }
426
427        while let Some(Reverse(stream_item)) = heap.peek() {
428            if stream_item.seq != next_seq {
429                break;
430            }
431
432            // Unwrap guaranteed to succeed.
433            let Reverse(stream_item) = heap.pop().unwrap();
434
435            if let Err(localq::mpsc::TrySendError::Full(packet)) =
436                self.local_queue_tx.try_send(stream_item.packet)
437            {
438                heap.push(Reverse(OrderedItem {
439                    seq: stream_item.seq,
440                    shutdown: stream_item.shutdown,
441                    packet,
442                }));
443                break;
444            }
445
446            self.next_seq.set(Some(next_seq + 1));
447        }
448
449        if let Some(last_seq) = self.last_seq.get() {
450            (last_seq - self.first_seq.get() + 1) == self.received_count.get()
451        } else {
452            false
453        }
454    }
455}