1use crate::proto::{
2 MultiStreamId, MultiStreamInitState, MultiStreamItem, MultiStreamItemLazy,
3 MultiStreamReceiverConfig,
4};
5use core::cell::{Cell, RefCell};
6use core::cmp::Reverse;
7use core::marker::PhantomData;
8use modrpc::RoleSetup;
9use std::collections::{BinaryHeap, HashMap};
10use std::rc::Rc;
11
12pub enum ReceiveMultiStreamNextError {
13 Shutdown,
14 DecodeItem(mproto::DecodeError),
15}
16
/// Receiving half of one multiplexed stream, returned by `MultiStreamReceiver::new_stream`.
///
/// Items are yielded in sequence order. In-order packets arrive through `local_queue_rx`;
/// packets the queue could not accept yet are held in `stream_state`'s reorder heap.
pub struct ReceiveMultiStream<T> {
    /// Id this stream was registered under with the broker.
    stream_id: MultiStreamId,
    /// Receiving end of the capacity-1 queue fed by `StreamState::handle_item`.
    local_queue_rx: localq::mpsc::Receiver<modrpc::Packet>,
    /// Reorder state, shared with the broker's registry until the stream completes.
    stream_state: Rc<StreamState>,
    /// `T` is only the decoded item type; no `T` value is stored.
    phantom: PhantomData<T>,
}
23
/// Stream registry shared by the builder's packet handler and every receiver handle.
struct BrokerState {
    /// Active streams keyed by `MultiStreamId::id`; an entry is removed once its stream
    /// reports completion from `StreamState::handle_item`.
    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
}
27
/// Handle used to open new receiving streams. All clones share one stream registry.
pub struct MultiStreamReceiver<T> {
    hooks: crate::MultiStreamReceiverHooks<T>,
    /// Stream registry shared with the packet handler installed by the builder.
    broker_state: Rc<BrokerState>,
}
32
/// Builder that creates receiver handles and installs the item-routing handler.
pub struct MultiStreamReceiverBuilder<T> {
    /// Role name; in this module it is used only in log messages.
    name: &'static str,
    hooks: crate::MultiStreamReceiverHooks<T>,
    stubs: crate::MultiStreamReceiverStubs<T>,

    /// Registry handed both to handles from `create_handle` and to the handler from `build`.
    broker_state: Rc<BrokerState>,
}
40
41impl<T: mproto::Owned> MultiStreamReceiver<T> {
42 pub fn new_stream(
43 &self,
44 stream_id: MultiStreamId,
45 next_seq: Option<u64>,
46 ) -> ReceiveMultiStream<T> {
47 let (local_queue_tx, local_queue_rx) = localq::mpsc::channel(1);
49 let stream_state = Rc::new(StreamState::new(local_queue_tx, next_seq));
50 self.broker_state
51 .streams
52 .borrow_mut()
53 .insert(stream_id.id, stream_state.clone());
54
55 ReceiveMultiStream {
56 stream_id,
57 local_queue_rx,
58 stream_state,
59 phantom: PhantomData,
60 }
61 }
62}
63
64impl<T> Clone for MultiStreamReceiver<T> {
65 fn clone(&self) -> Self {
66 Self {
67 hooks: self.hooks.clone(),
68 broker_state: self.broker_state.clone(),
69 }
70 }
71}
72
73impl<T: mproto::Owned> ReceiveMultiStream<T> {
74 pub fn id(&self) -> MultiStreamId {
75 self.stream_id
76 }
77
78 fn try_next_packet(&mut self) -> Option<modrpc::Packet> {
79 if let Ok(packet) = self.local_queue_rx.try_recv() {
80 return Some(packet);
81 }
82
83 self.stream_state.try_pop()
84 }
85
86 async fn next_packet(&mut self) -> modrpc::Packet {
87 if let Ok(packet) = self.local_queue_rx.try_recv() {
88 return packet;
89 }
90
91 if let Some(packet) = self.stream_state.try_pop() {
92 return packet;
93 }
94
95 self.local_queue_rx.recv().await.unwrap()
96 }
97
98 pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
99 use mproto::BaseLen;
100
101 let packet = self.next_packet().await;
102
103 let stream_item: MultiStreamItemLazy<T> =
104 mproto::decode_value(&packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..])
105 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
106
107 let owned_result = stream_item
108 .payload()
109 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
110 .map(|i| T::lazy_to_owned(i))
111 .transpose()
112 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
113
114 Ok(owned_result)
115 }
116
117 pub async fn next_lazy(
118 &mut self,
119 ) -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError> {
120 use mproto::BaseLen;
121
122 let packet = self.next_packet().await;
123 packet.advance(modrpc::TransmitPacket::BASE_LEN);
124
125 let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
126
127 Ok(stream_item.map(|s| s.payload().unwrap()))
128 }
129
130 pub fn with_try_next<R>(
131 &mut self,
132 f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
133 ) -> R {
134 use mproto::BaseLen;
135
136 let Some(packet) = self.try_next_packet() else {
137 return f(None);
138 };
139
140 let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
141 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
142 ) {
143 Ok(x) => x,
144 Err(e) => {
145 return f(Some(Err(e)));
146 }
147 };
148
149 let payload = match stream_item.payload() {
150 Ok(x) => x,
151 Err(e) => {
152 return f(Some(Err(e)));
153 }
154 };
155
156 f(Some(Ok(payload)))
157 }
158
159 pub async fn with_next<'a, Fut, R>(
160 &mut self,
161 f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
162 ) -> Option<R>
163 where
164 Fut: std::future::Future<Output = R>,
165 {
166 use mproto::BaseLen;
167
168 let packet = self.next_packet().await;
169
170 let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
171 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
172 ) {
173 Ok(x) => x,
174 Err(e) => {
175 return Some(f(Err(e)).await);
176 }
177 };
178
179 let payload = match stream_item.payload() {
180 Ok(x) => x,
181 Err(e) => {
182 return Some(f(Err(e)).await);
183 }
184 };
185
186 Some(f(Ok(payload)).await)
187 }
188
189 pub async fn with_next_sync<'a, R>(
190 &mut self,
191 f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
192 ) -> Option<R> {
193 use mproto::BaseLen;
194
195 let packet = self.next_packet().await;
196
197 let stream_item = match mproto::decode_value::<MultiStreamItemLazy<T>>(
198 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..],
199 ) {
200 Ok(x) => x,
201 Err(e) => {
202 return Some(f(Err(e)));
203 }
204 };
205
206 let payload = match stream_item.payload() {
207 Ok(Some(x)) => x,
208 Ok(None) => {
209 return None;
211 }
212 Err(e) => {
213 return Some(f(Err(e)));
214 }
215 };
216
217 Some(f(Ok(payload)))
218 }
219
220 pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
221 let mut collected = Vec::new();
222 while let Some(item) = self.next().await? {
223 collected.push(item);
224 }
225 Ok(collected)
226 }
227}
228
/// Error returned by `ReceiveMultiStream::try_collect` for streams whose items are `Result<T, E>`.
#[derive(thiserror::Error, Debug)]
pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
    /// A stream item failed to decode.
    #[error("failed to decode MultiStream item")]
    DecodeError(#[from] mproto::DecodeError),
    /// The stream delivered an `Err(E)` item; collection stops there.
    #[error("stream sender failed: {0:?}")]
    SenderError(E),
    /// Converted from `ReceiveMultiStreamNextError::Shutdown`.
    #[error("plane is shutting down")]
    Shutdown,
}
238
239impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
240 fn from(other: ReceiveMultiStreamNextError) -> Self {
241 match other {
242 ReceiveMultiStreamNextError::DecodeItem(e) => {
243 MultiStreamTryCollectError::DecodeError(e)
244 }
245 ReceiveMultiStreamNextError::Shutdown => MultiStreamTryCollectError::Shutdown,
246 }
247 }
248}
249
250impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
251 pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
252 let mut collected = Vec::new();
253 while let Some(item) = self
254 .next()
255 .await?
256 .transpose()
257 .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
258 {
259 collected.push(item);
260 }
261 Ok(collected)
262 }
263}
264
265impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
266 pub fn new(
267 name: &'static str,
268 hooks: crate::MultiStreamReceiverHooks<T>,
269 stubs: crate::MultiStreamReceiverStubs<T>,
270 _config: &MultiStreamReceiverConfig,
271 _init: MultiStreamInitState,
272 ) -> Self {
273 Self {
274 name,
275 hooks,
276 stubs,
277 broker_state: Rc::new(BrokerState {
278 streams: RefCell::new(HashMap::new()),
279 }),
280 }
281 }
282
283 pub fn create_handle(&self, _setup: &RoleSetup) -> crate::MultiStreamReceiver<T> {
284 crate::MultiStreamReceiver {
285 hooks: self.hooks.clone(),
286 broker_state: self.broker_state.clone(),
287 }
288 }
289
290 pub fn build(self, setup: &RoleSetup) {
291 use mproto::BaseLen;
292
293 let broker_state = self.broker_state;
294 self.stubs
295 .item
296 .inline_untyped(setup, move |_source, packet| {
297 let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
298 let (seq, stream_id, shutdown) = {
299 let Ok(stream_item) =
300 mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
301 else {
302 return;
303 };
304 let Ok(seq) = stream_item.seq() else {
305 return;
306 };
307 let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
308 return;
309 };
310 let Ok(payload) = stream_item.payload() else {
311 return;
312 };
313 (seq, stream_id, payload.is_none())
314 };
315
316 let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned()
317 else {
318 log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
319 return;
320 };
321
322 let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
323 if stream_is_done {
324 log::debug!(
325 "MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}"
326 );
327 broker_state.streams.borrow_mut().remove(&stream_id);
328 }
329 })
330 .subscribe();
331 }
332}
333
/// Entry in a stream's reorder heap.
struct OrderedItem {
    /// Sequence number; the only field used for equality and ordering.
    seq: u64,
    /// `true` when the item carries no payload, i.e. it is the end-of-stream marker.
    shutdown: bool,
    /// The whole packet, still including the transmit header.
    packet: modrpc::Packet,
}
341
342impl PartialEq for OrderedItem {
343 fn eq(&self, other: &Self) -> bool {
344 self.seq.eq(&other.seq)
345 }
346}
347
348impl Eq for OrderedItem {}
349
350impl PartialOrd for OrderedItem {
351 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
352 self.seq.partial_cmp(&other.seq)
353 }
354}
355
356impl Ord for OrderedItem {
357 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
358 self.seq.cmp(&other.seq)
359 }
360}
361
/// Per-stream reordering and completion-tracking state, shared between the broker's
/// packet handler and the stream's `ReceiveMultiStream`.
struct StreamState {
    /// Items not yet delivered, ordered as a min-heap on `seq` via `Reverse`.
    heap: RefCell<BinaryHeap<Reverse<OrderedItem>>>,
    /// Sequence number the stream is considered to start at; used by the completion check.
    first_seq: Cell<u64>,
    /// Sequence number of the end-of-stream item, once it has been received.
    last_seq: Cell<Option<u64>>,
    /// Number of items accepted (not discarded as stale) so far.
    received_count: Cell<u64>,
    /// Next sequence number to deliver; `None` until a starting point is established.
    next_seq: Cell<Option<u64>>,
    /// Sending half of the stream's local queue.
    local_queue_tx: localq::mpsc::Sender<modrpc::Packet>,
}
372
373impl StreamState {
374 fn new(local_queue_tx: localq::mpsc::Sender<modrpc::Packet>, next_seq: Option<u64>) -> Self {
375 Self {
376 heap: RefCell::new(BinaryHeap::new()),
377 first_seq: Cell::new(0),
378 last_seq: Cell::new(None),
379 received_count: Cell::new(0),
380 next_seq: Cell::new(next_seq),
381 local_queue_tx,
382 }
383 }
384
385 fn try_pop(&self) -> Option<modrpc::Packet> {
386 let mut heap = self.heap.borrow_mut();
387 let Reverse(stream_item) = heap.peek()?;
388
389 let next_seq = self.next_seq.get().unwrap_or_else(|| {
390 self.first_seq.set(stream_item.seq);
391 stream_item.seq
392 });
393
394 if stream_item.seq != next_seq {
395 return None;
396 }
397 self.next_seq.set(Some(next_seq + 1));
398
399 Some(heap.pop().unwrap().0.packet)
400 }
401
402 fn handle_item(&self, seq: u64, shutdown: bool, packet: modrpc::Packet) -> bool {
404 let mut heap = self.heap.borrow_mut();
405
406 let next_seq = self.next_seq.get().unwrap_or_else(|| {
408 self.first_seq.set(seq);
409 seq
410 });
411 if seq < next_seq {
413 return false;
414 }
415
416 heap.push(Reverse(OrderedItem {
418 seq,
419 shutdown,
420 packet,
421 }));
422 self.received_count.set(self.received_count.get() + 1);
423 if shutdown {
424 self.last_seq.set(Some(seq));
425 }
426
427 while let Some(Reverse(stream_item)) = heap.peek() {
428 if stream_item.seq != next_seq {
429 break;
430 }
431
432 let Reverse(stream_item) = heap.pop().unwrap();
434
435 if let Err(localq::mpsc::TrySendError::Full(packet)) =
436 self.local_queue_tx.try_send(stream_item.packet)
437 {
438 heap.push(Reverse(OrderedItem {
439 seq: stream_item.seq,
440 shutdown: stream_item.shutdown,
441 packet,
442 }));
443 break;
444 }
445
446 self.next_seq.set(Some(next_seq + 1));
447 }
448
449 if let Some(last_seq) = self.last_seq.get() {
450 (last_seq - self.first_seq.get() + 1) == self.received_count.get()
451 } else {
452 false
453 }
454 }
455}