1use core::{
2 cell::RefCell,
3 marker::PhantomData,
4};
5use crate::{
6 proto::{
7 MultiStreamInitState,
8 MultiStreamItem,
9 MultiStreamItemLazy,
10 MultiStreamId,
11 MultiStreamReceiverConfig,
12 },
13 receive_stream::{ReceiveStream, StreamState},
14};
15use std::collections::HashMap;
16use std::rc::Rc;
17use modrpc::RoleSetup;
18
19pub enum ReceiveMultiStreamNextError {
20 Shutdown,
21 DecodeItem(mproto::DecodeError),
22}
23
24pub struct ReceiveMultiStream<T> {
25 stream_id: MultiStreamId,
26 receive_stream: ReceiveStream,
27 phantom: PhantomData<T>,
28}
29
30struct BrokerState {
31 streams: RefCell<HashMap<u32, Rc<StreamState>>>,
32}
33
34pub struct MultiStreamReceiver<T> {
35 hooks: crate::MultiStreamReceiverHooks<T>,
36 broker_state: Rc<BrokerState>,
37}
38
39pub struct MultiStreamReceiverBuilder<T> {
40 name: &'static str,
41 hooks: crate::MultiStreamReceiverHooks<T>,
42 stubs: crate::MultiStreamReceiverStubs<T>,
43
44 broker_state: Rc<BrokerState>,
45}
46
47impl<T: mproto::Owned> MultiStreamReceiver<T> {
48 pub fn new_stream(&self, stream_id: MultiStreamId, next_seq: Option<u64>) -> ReceiveMultiStream<T> {
49 let receive_stream = ReceiveStream::new(next_seq);
50 self.broker_state.streams.borrow_mut()
51 .insert(stream_id.id, receive_stream.stream_state().clone());
52
53 ReceiveMultiStream {
54 stream_id,
55 receive_stream,
56 phantom: PhantomData,
57 }
58 }
59}
60
61impl<T> Clone for MultiStreamReceiver<T> {
62 fn clone(&self) -> Self {
63 Self {
64 hooks: self.hooks.clone(),
65 broker_state: self.broker_state.clone(),
66 }
67 }
68}
69
70impl<T: mproto::Owned> ReceiveMultiStream<T> {
71 pub fn id(&self) -> MultiStreamId {
72 self.stream_id
73 }
74
75 pub async fn next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
76 use mproto::BaseLen;
77
78 let packet = self.receive_stream.next_packet().await;
79
80 let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
81 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
82 )
83 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
84
85 let owned_result = stream_item.payload()
86 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
87 .map(|i| T::lazy_to_owned(i))
88 .transpose()
89 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
90
91 Ok(owned_result)
92 }
93
94 pub async fn next_lazy(&mut self)
95 -> Result<mproto::LazyBuf<Option<T>, modrpc::Packet>, ReceiveMultiStreamNextError>
96 {
97 use mproto::BaseLen;
98
99 let packet = self.receive_stream.next_packet().await;
100 packet.advance(modrpc::TransmitPacket::BASE_LEN);
101
102 let stream_item: mproto::LazyBuf<MultiStreamItem<T>, _> = mproto::LazyBuf::new(packet);
103
104 Ok(stream_item.map(|s| s.payload().unwrap()))
105 }
106
107 pub fn try_next(&mut self) -> Result<Option<T>, ReceiveMultiStreamNextError> {
108 use mproto::BaseLen;
109
110 let Some(packet) = self.receive_stream.try_next_packet() else {
111 return Ok(None);
112 };
113
114 let stream_item: MultiStreamItemLazy<T> = mproto::decode_value(
115 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
116 )
117 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
118
119 let owned_result = stream_item.payload()
120 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?
121 .map(|i| T::lazy_to_owned(i))
122 .transpose()
123 .map_err(|e| ReceiveMultiStreamNextError::DecodeItem(e))?;
124
125 Ok(owned_result)
126 }
127
128 pub fn with_try_next<R>(
129 &mut self,
130 f: impl FnOnce(Option<mproto::DecodeResult<Option<T::Lazy<'_>>>>) -> R,
131 ) -> R {
132 use mproto::BaseLen;
133
134 let Some(packet) = self.receive_stream.try_next_packet() else {
135 return f(None);
136 };
137
138 let stream_item =
139 match mproto::decode_value::<MultiStreamItemLazy<T>>(
140 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
141 ) {
142 Ok(x) => x,
143 Err(e) => {
144 return f(Some(Err(e)));
145 },
146 };
147
148 let payload =
149 match stream_item.payload() {
150 Ok(x) => x,
151 Err(e) => {
152 return f(Some(Err(e)));
153 },
154 };
155
156 f(Some(Ok(payload)))
157 }
158
159 pub async fn with_next<'a, Fut, R>(
160 &mut self,
161 f: impl FnOnce(mproto::DecodeResult<Option<T::Lazy<'_>>>) -> Fut,
162 ) -> Option<R>
163 where Fut: std::future::Future<Output = R>
164 {
165 use mproto::BaseLen;
166
167 let packet = self.receive_stream.next_packet().await;
168
169 let stream_item =
170 match mproto::decode_value::<MultiStreamItemLazy<T>>(
171 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
172 ) {
173 Ok(x) => x,
174 Err(e) => {
175 return Some(f(Err(e)).await);
176 },
177 };
178
179 let payload =
180 match stream_item.payload() {
181 Ok(x) => x,
182 Err(e) => {
183 return Some(f(Err(e)).await);
184 },
185 };
186
187 Some(f(Ok(payload)).await)
188 }
189
190 pub async fn with_next_sync<'a, R>(
191 &mut self,
192 f: impl FnOnce(mproto::DecodeResult<T::Lazy<'_>>) -> R,
193 ) -> Option<R> {
194 use mproto::BaseLen;
195
196 let packet = self.receive_stream.next_packet().await;
197
198 let stream_item =
199 match mproto::decode_value::<MultiStreamItemLazy<T>>(
200 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
201 ) {
202 Ok(x) => x,
203 Err(e) => {
204 return Some(f(Err(e)));
205 }
206 };
207
208 let payload =
209 match stream_item.payload() {
210 Ok(Some(x)) => x,
211 Ok(None) => {
212 return None;
214 }
215 Err(e) => {
216 return Some(f(Err(e)));
217 }
218 };
219
220 Some(f(Ok(payload)))
221 }
222
223 pub async fn collect(&mut self) -> Result<Vec<T>, ReceiveMultiStreamNextError> {
224 let mut collected = Vec::new();
225 while let Some(item) = self.next().await? {
226 collected.push(item);
227 }
228 Ok(collected)
229 }
230}
231
232#[derive(thiserror::Error, Debug)]
233pub enum MultiStreamTryCollectError<E: std::fmt::Debug> {
234 #[error("failed to decode MultiStream item")]
235 DecodeError(#[from] mproto::DecodeError),
236 #[error("stream sender failed: {0:?}")]
237 SenderError(E),
238 #[error("plane is shutting down")]
239 Shutdown,
240}
241
242impl<E: std::fmt::Debug> From<ReceiveMultiStreamNextError> for MultiStreamTryCollectError<E> {
243 fn from(other: ReceiveMultiStreamNextError) -> Self {
244 match other {
245 ReceiveMultiStreamNextError::DecodeItem(e) =>
246 MultiStreamTryCollectError::DecodeError(e),
247 ReceiveMultiStreamNextError::Shutdown =>
248 MultiStreamTryCollectError::Shutdown,
249 }
250 }
251}
252
253impl<T: mproto::Owned, E: mproto::Owned + std::fmt::Debug> ReceiveMultiStream<Result<T, E>> {
254 pub async fn try_collect(&mut self) -> Result<Vec<T>, MultiStreamTryCollectError<E>> {
255 let mut collected = Vec::new();
256 while let Some(item) =
257 self.next().await?
258 .transpose()
259 .map_err(|e| MultiStreamTryCollectError::SenderError(e))?
260 {
261 collected.push(item);
262 }
263 Ok(collected)
264 }
265}
266
267impl<T: mproto::Owned> MultiStreamReceiverBuilder<T> {
268 pub fn new(
269 name: &'static str,
270 hooks: crate::MultiStreamReceiverHooks<T>,
271 stubs: crate::MultiStreamReceiverStubs<T>,
272 _config: &MultiStreamReceiverConfig,
273 _init: MultiStreamInitState,
274 ) -> Self {
275 Self {
276 name, hooks, stubs,
277 broker_state: Rc::new(BrokerState {
278 streams: RefCell::new(HashMap::new()),
279 }),
280 }
281 }
282
283 pub fn create_handle(
284 &self,
285 _setup: &RoleSetup,
286 ) -> crate::MultiStreamReceiver<T> {
287 crate::MultiStreamReceiver {
288 hooks: self.hooks.clone(),
289 broker_state: self.broker_state.clone(),
290 }
291 }
292
293 pub fn build(
294 self,
295 setup: &RoleSetup,
296 ) {
297 use mproto::BaseLen;
298
299 let broker_state = self.broker_state;
300 self.stubs.item.inline_untyped(setup, move |_source, packet| {
301 let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
302 let (seq, stream_id, shutdown) = {
303 let Ok(stream_item) =
304 mproto::decode_value::<MultiStreamItemLazy<T>>(stream_item_bytes)
305 else {
306 return;
307 };
308 let Ok(seq) = stream_item.seq() else {
309 return;
310 };
311 let Ok(stream_id) = stream_item.stream_id().and_then(|r| r.id()) else {
312 return;
313 };
314 let Ok(payload) = stream_item.payload() else {
315 return;
316 };
317 (seq, stream_id, payload.is_none())
318 };
319
320 let Some(stream_state) = broker_state.streams.borrow().get(&stream_id).cloned() else {
321 log::warn!("Unknown stream_id name={} stream_id={stream_id}", self.name);
322 return;
323 };
324
325 let stream_is_done = stream_state.handle_item(seq, shutdown, packet.clone());
326 if stream_is_done {
327 log::debug!("MultiStreamReciever shutdown stream stream_id={stream_id} seq={seq}");
328 broker_state.streams.borrow_mut().remove(&stream_id);
329 }
330 })
331 .subscribe();
332 }
333}
334