1use core::cell::{Cell, RefCell};
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4use crate::proto::{
5 MultiStreamInitState,
6 MultiStreamItem,
7 MultiStreamItemLazy,
8 MultiStreamId,
9 MultiStreamReceiverConfig,
10};
11use std::collections::{BinaryHeap, HashMap};
12use std::rc::Rc;
13use modrpc::RoleSetup;
14
15pub enum ReceiveMultiStreamNextError {
16 Shutdown,
17 DecodeItem(mproto::DecodeError),
18}
19
/// Receiving half of a single multiplexed stream, created by `MultiStreamReceiver::new_stream`.
///
/// Items are yielded in sequence-number order: packets the broker has already released in order
/// arrive through `local_queue_rx`, while others wait in the reorder heap inside `stream_state`.
pub struct ReceiveMultiStream<T> {
    stream_id: MultiStreamId,
    // Packets released in order by the broker's routing closure.
    local_queue_rx: localq::mpsc::Receiver<modrpc::Packet>,
    // Reorder state shared with the broker; it also owns the sender of `local_queue_rx`.
    stream_state: Rc<StreamState>,
    // Item type decoded from the packets; no `T` is stored.
    phantom: PhantomData<T>,
}
26
/// Registry of live streams, shared between receiver handles and the packet-routing closure.
struct BrokerState {
    // Keyed by `MultiStreamId::id`; an entry is removed once its stream reports completion.
    streams: RefCell<HashMap<u32, Rc<StreamState>>>,
}
30
/// Cloneable handle used to open new receiving streams on a shared broker.
pub struct MultiStreamReceiver<T> {
    // NOTE(review): not read by any method visible in this file beyond cloning.
    hooks: crate::MultiStreamReceiverHooks<T>,
    // Stream registry shared with the builder's routing closure.
    broker_state: Rc<BrokerState>,
}
35
/// Builds the receiving side of a multi-stream: hands out `MultiStreamReceiver` handles and, in
/// `build`, installs the closure that routes incoming item packets to registered streams.
pub struct MultiStreamReceiverBuilder<T> {
    // Included in log messages emitted by the routing closure.
    name: &'static str,
    hooks: crate::MultiStreamReceiverHooks<T>,
    stubs: crate::MultiStreamReceiverStubs<T>,

    // Shared with every handle produced by `create_handle`.
    broker_state: Rc<BrokerState>,
}
43
44impl<T: mproto::Owned> MultiStreamReceiver<T> {
45 pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
46 let (local_queue_tx, local_queue_rx) = localq::mpsc::channel(1);
48 let stream_state = Rc::new(StreamState::new(local_queue_tx, next_seq));
49 self.broker_state.streams.borrow_mut().insert(stream_id.id, stream_state.clone());
50
51 ReceiveMultiStream {
52 stream_id,
53 local_queue_rx,
54 stream_state,
55 phantom: PhantomData,
56 }
57 }
58}
59
60impl<T> Clone for MultiStreamReceiver<T> {
61 fn clone(&self) -> Self {
62 Self {
63 hooks: self.hooks.clone(),
64 broker_state: self.broker_state.clone(),
65 }
66 }
67}
68
69impl<T: mproto::Owned> ReceiveMultiStream<T> {
70 pub fn id(&self) -> MultiStreamId {
71 self.stream_id
72 }
73
74 fn try_next_packet(&mut self) -> Option<modrpc::Packet> {
75 if let Ok(packet) = self.local_queue_rx.try_recv() {
76 return Some(packet);
77 }
78
79 self.stream_state.try_pop()
80 }
81
82 async fn next_packet(&mut self) -> modrpc::Packet {
83 if let Ok(packet) = self.local_queue_rx.try_recv() {
84 return packet;
85 }
86
87 if let Some(packet) = self.stream_state.try_pop() {
88 return packet;
89 }
90
91 self.local_queue_rx.recv().await.unwrap()
92 }
93
94 pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
95 use mproto::BaseLen;
96
97 let packet = self.next_packet().await;
98
99 let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
100 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
101 )
102 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
103
104 let owned_result = stream_item.payload()
105 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
106 .map(|i| T::lazy_to_owned(i))
107 .transpose()
108 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
109
110 Ok(owned_result)
111 }
112
113 pub async fn next_lazy(&mut self)
114 -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
115 {
116 use mproto::BaseLen;
117
118 let packet = self.next_packet().await;
119 packet.advance(modrpc::TransmitPacket::BASE_LEN);
120
121 let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
122
123 Ok(stream_item.map(|s| s.payload().unwrap()))
124 }
125
126 pub fn with_try_next<R>(
127 &mut self,
128 f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
129 ) -> R {
130 use mproto::BaseLen;
131
132 let Some(packet) = self.try_next_packet() else {
133 return f(None);
134 };
135
136 let stream_item =
137 match mproto::decode_value::<MultiStreamItemLazy<T>>(
138 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
139 ) {
140 Ok(x) => x,
141 Err(e) => {
142 return f(Some(Err(e)));
143 },
144 };
145
146 let payload =
147 match stream_item.payload() {
148 Ok(x) => x,
149 Err(e) => {
150 return f(Some(Err(e)));
151 },
152 };
153
154 f(Some(Ok(payload)))
155 }
156
157 pub async fn with_next<'a, Fut, R>(
158 &mut self,
159 f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
160 ) -> Option<R>
161 where Fut: std::future::Future<Output = R>
162 {
163 use mproto::BaseLen;
164
165 let packet = self.next_packet().await;
166
167 let stream_item =
168 match mproto::decode_value::<MultiStreamItemLazy<T>>(
169 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
170 ) {
171 Ok(x) => x,
172 Err(e) => {
173 return Some(f(Err(e)).await);
174 },
175 };
176
177 let payload =
178 match stream_item.payload() {
179 Ok(x) => x,
180 Err(e) => {
181 return Some(f(Err(e)).await);
182 },
183 };
184
185 Some(f(Ok(payload)).await)
186 }
187
188 pub async fn with_next_sync<'a, R>(
189 &mut self,
190 f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
191 ) -> Option<R> {
192 use mproto::BaseLen;
193
194 let packet = self.next_packet().await;
195
196 let stream_item =
197 match mproto::decode_value::<MultiStreamItemLazy<T>>(
198 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
199 ) {
200 Ok(x) => x,
201 Err(e) => {
202 return Some(f(Err(e)));
203 }
204 };
205
206 let payload =
207 match stream_item.payload() {
208 Ok(Some(x)) => x,
209 Ok(None) => {
210 return None;
212 }
213 Err(e) => {
214 return Some(f(Err(e)));
215 }
216 };
217
218 Some(f(Ok(payload)))
219 }
220
221 pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
222 let mut collected = Vec::new();
223 while let Some(item) = self.next().await? {
224 collected.push(item);
225 }
226 Ok(collected)
227 }
228}
229
/// Error returned by `ReceiveMultiStream::try_collect`.
#[derive(thiserror::Error, Debug)]
pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
    /// A received item could not be decoded.
    #[error("failed to decode MultiStream item")]
    DecodeError(#[from] mproto::DecodeError),
    /// The sender delivered an `Err(E)` item.
    #[error("stream sender failed: {0:?}")]
    SenderError(E),
    /// Converted from `ReceiveMultiStreamNextError::Shutdown`.
    #[error("plane is shutting down")]
    Shutdown,
}
239
240impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
241 fn from(other: ReceiveMultiStreamNextError) -> Self {
242 match other {
243 ReceiveMultiStreamNextError::DecodeItem(e) =>
244 MultiStreamTryCollectError::DecodeError(e),
245 ReceiveMultiStreamNextError::Shutdown =>
246 MultiStreamTryCollectError::Shutdown,
247 }
248 }
249}
250
251impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
252 pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
253 let mut collected = Vec::new();
254 while let Some(item) =
255 self.next().await?
256 .transpose()
257 .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
258 {
259 collected.push(item);
260 }
261 Ok(collected)
262 }
263}
264
265impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
266 pub fn new(
267 name: &'static str,
268 hooks: crate::MultiStreamReceiverHooks<T>,
269 stubs: crate::MultiStreamReceiverStubs<T>,
270 _config: &MultiStreamReceiverConfig,
271 _init: MultiStreamInitState,
272 ) -> Self {
273 Self {
274 name, hooks, stubs,
275 broker_state: Rc::new(BrokerState {
276 streams: RefCell::new(HashMap::new()),
277 }),
278 }
279 }
280
281 pub fn create_handle(
282 &self,
283 _setup: &RoleSetup,
284 ) -> crate::MultiStreamReceiver<T> {
285 crate::MultiStreamReceiver {
286 hooks: self.hooks.clone(),
287 broker_state: self.broker_state.clone(),
288 }
289 }
290
291 pub fn build(
292 self,
293 setup: &RoleSetup,
294 ) {
295 use mproto::BaseLen;
296
297 let broker_state = self.broker_state;
298 self.stubs.item.inline_untyped(setup, move |_source, packet| {
299 let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
300 let (seq, stream_id, shutdown) = {
301 let Ok(stream_item) =
302 mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
303 else {
304 return;
305 };
306 let Ok(seq) = stream_item.seq() else {
307 return;
308 };
309 let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
310 return;
311 };
312 let Ok(payload) = stream_item.payload() else {
313 return;
314 };
315 (seq, stream_id, payload.is_none())
316 };
317
318 let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
319 log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
320 return;
321 };
322
323 let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
324 if stream_is_done {
325 log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
326 broker_state.streams.borrow_mut().remove(&stream_id);
327 }
328 })
329 .subscribe();
330 }
331}
332
/// A packet held in a stream's reorder heap; compared and ordered by `seq` alone.
struct OrderedItem {
    seq: u64,
    // True for the end-of-stream item (the broker sets it when the item has no payload).
    shutdown: bool,
    packet: modrpc::Packet,
}
340
341impl PartialEq for OrderedItem {
342 fn eq(&self, other: &Self) -> bool { self.seq.eq(&other.seq) }
343}
344
// Equality on `seq: u64` is total, so the marker trait is sound.
impl Eq for OrderedItem { }
346
347impl PartialOrd for OrderedItem {
348 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
349 self.seq.partial_cmp(&other.seq)
350 }
351}
352
353impl Ord for OrderedItem {
354 fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.seq.cmp(&other.seq) }
355}
356
/// Per-stream reordering and completion-tracking state, shared between a `ReceiveMultiStream`
/// and the broker's routing closure.
struct StreamState {
    // Items not yet handed to the receiver; `Reverse` makes the smallest `seq` the heap's top.
    heap: RefCell<BinaryHeap<Reverse<OrderedItem>>>,
    // Sequence number from which `received_count` is compared for completion.
    first_seq: Cell<u64>,
    // Sequence number of the end-of-stream item, once it has arrived.
    last_seq: Cell<Option<u64>>,
    // Number of items accepted into the heap so far.
    received_count: Cell<u64>,
    // Next sequence number to release; `None` until the first item arrives if no start was given.
    next_seq: Cell<Option<u64>>,
    // Sender side of the receiver's local queue.
    local_queue_tx: localq::mpsc::Sender<modrpc::Packet>,
}
367
368impl StreamState {
369 fn new(local_queue_tx: localq::mpsc::Sender<modrpc::Packet>, next_seq: Option<u64>) -> Self {
370 Self {
371 heap: RefCell::new(BinaryHeap::new()),
372 first_seq: Cell::new(0),
373 last_seq: Cell::new(None),
374 received_count: Cell::new(0),
375 next_seq: Cell::new(next_seq),
376 local_queue_tx,
377 }
378 }
379
380 fn try_pop(&self) -> Option<modrpc::Packet> {
381 let mut heap = self.heap.borrow_mut();
382 let Reverse(stream_item) = heap.peek()?;
383
384 let next_seq = self.next_seq.get().unwrap_or_else(|| {
385 self.first_seq.set(stream_item.seq);
386 stream_item.seq
387 });
388
389 if stream_item.seq != next_seq {
390 return None;
391 }
392 self.next_seq.set(Some(next_seq + 1));
393
394 Some(heap.pop().unwrap().0.packet)
395 }
396
397 fn handle_item(&self, seq: u64, shutdown: bool, packet: modrpc::Packet) -> bool {
399 let mut heap = self.heap.borrow_mut();
400
401 let next_seq = self.next_seq.get().unwrap_or_else(|| {
403 self.first_seq.set(seq);
404 seq
405 });
406 if seq < next_seq {
408 return false;
409 }
410
411 heap.push(Reverse(OrderedItem { seq, shutdown, packet }));
413 self.received_count.set(self.received_count.get() + 1);
414 if shutdown {
415 self.last_seq.set(Some(seq));
416 }
417
418 while let Some(Reverse(stream_item)) = heap.peek() {
419 if stream_item.seq != next_seq { break; }
420
421 let Reverse(stream_item) = heap.pop().unwrap();
423
424 if let Err(localq::mpsc::TrySendError::Full(packet)) =
425 self.local_queue_tx.try_send(stream_item.packet)
426 {
427 heap.push(Reverse(OrderedItem {
428 seq: stream_item.seq,
429 shutdown: stream_item.shutdown,
430 packet,
431 }));
432 break;
433 }
434
435 self.next_seq.set(Some(next_seq + 1));
436 }
437
438 if let Some(last_seq) = self.last_seq.get() {
439 (last_seq - self.first_seq.get() + 1) == self.received_count.get()
440 } else {
441 false
442 }
443 }
444}
445