1use core::{
2 cell::RefCell,
3 marker::PhantomData,
4};
5use crate::{
6 proto::{
7 MultiStreamInitState,
8 MultiStreamItem,
9 MultiStreamItemLazy,
10 MultiStreamId,
11 MultiStreamReceiverConfig,
12 },
13 receive_stream::{ReceiveStream, StreamState},
14};
15use std::collections::HashMap;
16use std::rc::Rc;
17use modrpc::RoleSetup;
18
19pub enum ReceiveMultiStreamNextError {
20 Shutdown,
21 DecodeItem(mproto::DecodeError),
22}
23
24pub struct ReceiveMultiStream<T> {
25 stream_id: MultiStreamId,
26 receive_stream: ReceiveStream,
27 phantom: PhantomData<T>,
28}
29
30struct BrokerState {
31 streams: RefCell<HashMap<u32, Rc<StreamState>>>,
32}
33
34pub struct MultiStreamReceiver<T> {
35 hooks: crate::MultiStreamReceiverHooks<T>,
36 broker_state: Rc<BrokerState>,
37}
38
39pub struct MultiStreamReceiverBuilder<T> {
40 name: &'static str,
41 hooks: crate::MultiStreamReceiverHooks<T>,
42 stubs: crate::MultiStreamReceiverStubs<T>,
43
44 broker_state: Rc<BrokerState>,
45}
46
47impl<T: mproto::Owned> MultiStreamReceiver<T> {
48 pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
49 let receive_stream = ReceiveStream::new(next_seq);
50 self.broker_state.streams.borrow_mut()
51 .insert(stream_id.id, receive_stream.stream_state().clone());
52
53 ReceiveMultiStream {
54 stream_id,
55 receive_stream,
56 phantom: PhantomData,
57 }
58 }
59}
60
61impl<T> Clone for MultiStreamReceiver<T> {
62 fn clone(&self) -> Self {
63 Self {
64 hooks: self.hooks.clone(),
65 broker_state: self.broker_state.clone(),
66 }
67 }
68}
69
70impl<T: mproto::Owned> ReceiveMultiStream<T> {
71 pub fn id(&self) -> MultiStreamId {
72 self.stream_id
73 }
74
75 pub async fn next(&mut self) -> Result<T, ReceiveMultiStreamNextError> {
76 use mproto::BaseLen;
77
78 let packet = self.receive_stream.next_packet().await;
79
80 let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
81 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
82 )
83 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
84
85 let owned_result = T::lazy_to_owned(
86 stream_item.payload()
87 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
88 .ok_or(ReceiveMultiStreamNextError::Shutdown)?
89 )
90 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
91
92 Ok(owned_result)
93 }
94
95 pub async fn next_lazy(&mut self)
96 -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
97 {
98 use mproto::BaseLen;
99
100 let packet = self.receive_stream.next_packet().await;
101 packet.advance(modrpc::TransmitPacket::BASE_LEN);
102
103 let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
104
105 Ok(stream_item.map(|s| s.payload().unwrap()))
106 }
107
108 pub fn try_next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
109 use mproto::BaseLen;
110
111 let Some(packet) = self.receive_stream.try_next_packet() else {
112 return Ok(None);
113 };
114
115 let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
116 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
117 )
118 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
119
120 let owned_result = T::lazy_to_owned(
121 stream_item.payload()
122 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
123 .ok_or(ReceiveMultiStreamNextError::Shutdown)?
124 )
125 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
126
127 Ok(Some(owned_result))
128 }
129
130 pub fn with_try_next<R>(
131 &mut self,
132 f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
133 ) -> R {
134 use mproto::BaseLen;
135
136 let Some(packet) = self.receive_stream.try_next_packet() else {
137 return f(None);
138 };
139
140 let stream_item =
141 match mproto::decode_value::<MultiStreamItemLazy<T>>(
142 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
143 ) {
144 Ok(x) => x,
145 Err(e) => {
146 return f(Some(Err(e)));
147 },
148 };
149
150 let payload =
151 match stream_item.payload() {
152 Ok(x) => x,
153 Err(e) => {
154 return f(Some(Err(e)));
155 },
156 };
157
158 f(Some(Ok(payload)))
159 }
160
161 pub async fn with_next<'a, Fut, R>(
162 &mut self,
163 f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
164 ) -> Option<R>
165 where Fut: std::future::Future<Output = R>
166 {
167 use mproto::BaseLen;
168
169 let packet = self.receive_stream.next_packet().await;
170
171 let stream_item =
172 match mproto::decode_value::<MultiStreamItemLazy<T>>(
173 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
174 ) {
175 Ok(x) => x,
176 Err(e) => {
177 return Some(f(Err(e)).await);
178 },
179 };
180
181 let payload =
182 match stream_item.payload() {
183 Ok(x) => x,
184 Err(e) => {
185 return Some(f(Err(e)).await);
186 },
187 };
188
189 Some(f(Ok(payload)).await)
190 }
191
192 pub async fn with_next_sync<'a, R>(
193 &mut self,
194 f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
195 ) -> Option<R> {
196 use mproto::BaseLen;
197
198 let packet = self.receive_stream.next_packet().await;
199
200 let stream_item =
201 match mproto::decode_value::<MultiStreamItemLazy<T>>(
202 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
203 ) {
204 Ok(x) => x,
205 Err(e) => {
206 return Some(f(Err(e)));
207 }
208 };
209
210 let payload =
211 match stream_item.payload() {
212 Ok(Some(x)) => x,
213 Ok(None) => {
214 return None;
216 }
217 Err(e) => {
218 return Some(f(Err(e)));
219 }
220 };
221
222 Some(f(Ok(payload)))
223 }
224
225 pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
226 let mut collected = Vec::new();
227 loop {
228 match self.next().await {
229 Ok(item) => { collected.push(item); Ok(()) }
230 Err(ReceiveMultiStreamNextError::Shutdown) => break,
231 Err(e) => Err(e),
232 }?;
233 }
234 Ok(collected)
235 }
236}
237
238#[derive(thiserror::Error, Debug)]
239pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
240 #[error("failed to decode MultiStream item")]
241 DecodeError(#[from] mproto::DecodeError),
242 #[error("stream sender failed: {0:?}")]
243 SenderError(E),
244 #[error("plane is shutting down")]
245 Shutdown,
246}
247
248impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
249 fn from(other: ReceiveMultiStreamNextError) -> Self {
250 match other {
251 ReceiveMultiStreamNextError::DecodeItem(e) =>
252 MultiStreamTryCollectError::DecodeError(e),
253 ReceiveMultiStreamNextError::Shutdown =>
254 MultiStreamTryCollectError::Shutdown,
255 }
256 }
257}
258
259impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
260 pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
261 let mut collected = Vec::new();
262 loop {
263 match self.next().await {
264 Ok(Ok(item)) => { collected.push(item); Ok(()) }
265 Ok(Err(e)) => return Err(MultiStreamTryCollectError::SenderError(e.into())),
266 Err(ReceiveMultiStreamNextError::Shutdown) => break,
267 Err(e) => Err(e),
268 }?;
269 }
270 Ok(collected)
271 }
272}
273
274impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
275 pub fn new(
276 name: &'static str,
277 hooks: crate::MultiStreamReceiverHooks<T>,
278 stubs: crate::MultiStreamReceiverStubs<T>,
279 _config: &MultiStreamReceiverConfig,
280 _init: MultiStreamInitState,
281 ) -> Self {
282 Self {
283 name, hooks, stubs,
284 broker_state: Rc::new(BrokerState {
285 streams: RefCell::new(HashMap::new()),
286 }),
287 }
288 }
289
290 pub fn create_handle(
291 &self,
292 _setup: &RoleSetup,
293 ) -> crate::MultiStreamReceiver<T> {
294 crate::MultiStreamReceiver {
295 hooks: self.hooks.clone(),
296 broker_state: self.broker_state.clone(),
297 }
298 }
299
300 pub fn build(
301 self,
302 setup: &RoleSetup,
303 ) {
304 use mproto::BaseLen;
305
306 let broker_state = self.broker_state;
307 self.stubs.item.inline_untyped(setup, move |_source, packet| {
308 let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
309 let (seq, stream_id, shutdown) = {
310 let Ok(stream_item) =
311 mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
312 else {
313 return;
314 };
315 let Ok(seq) = stream_item.seq() else {
316 return;
317 };
318 let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
319 return;
320 };
321 let Ok(payload) = stream_item.payload() else {
322 return;
323 };
324 (seq, stream_id, payload.is_none())
325 };
326
327 let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
328 log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
329 return;
330 };
331
332 let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
333 if stream_is_done {
334 log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
335 broker_state.streams.borrow_mut().remove(&stream_id);
336 }
337 })
338 .subscribe();
339 }
340}
341