std_modrpc/role_impls/
stream_receiver.rs

1use std::{
2    cell::RefCell,
3    rc::Rc,
4};
5use modrpc::RoleSetup;
6
7use crate::{
8    proto::{StreamInitState, StreamItem, StreamItemLazy, StreamReceiverConfig},
9    receive_stream::{ReceiveStream, StreamState},
10};
11
12#[derive(Clone)]
13pub struct StreamReceiver<T> {
14    subscriptions: Rc<Subscriptions>,
15    _phantom: core::marker::PhantomData<T>,
16}
17
18impl<T: mproto::Owned> StreamReceiver<T> {
19    pub fn subscribe(&self, next_seq: Option<u64>) -> StreamSubscription<T> {
20        let receive_stream = ReceiveStream::new(next_seq);
21        self.subscriptions.stream_states.borrow_mut().push(receive_stream.stream_state().clone());
22        StreamSubscription {
23            receive_stream,
24            subscriptions: self.subscriptions.clone(),
25            _phantom: core::marker::PhantomData,
26        }
27    }
28}
29
30pub struct StreamSubscription<T> {
31    receive_stream: ReceiveStream,
32    subscriptions: Rc<Subscriptions>,
33    _phantom: core::marker::PhantomData<T>,
34}
35
36impl<T> Drop for StreamSubscription<T> {
37    fn drop(&mut self) {
38        self.subscriptions.stream_states.borrow_mut()
39            .retain(|s| Rc::as_ptr(s) != Rc::as_ptr(self.receive_stream.stream_state()));
40    }
41}
42
43impl<T: mproto::Owned> StreamSubscription<T> {
44    pub async fn next(&mut self) -> mproto::DecodeResult<T> {
45        use mproto::BaseLen;
46
47        let packet = self.receive_stream.next_packet().await;
48
49        let stream_item: StreamItemLazy<T> = mproto::decode_value(
50            &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
51        )?;
52        let owned_result = stream_item.payload().map(|i| T::lazy_to_owned(i))??;
53
54        Ok(owned_result)
55    }
56
57    pub async fn next_lazy(&mut self)
58        -> mproto::DecodeResult<mproto::LazyBuf<T, modrpc::Packet>>
59    {
60        use mproto::BaseLen;
61
62        let packet = self.receive_stream.next_packet().await;
63        packet.advance(modrpc::TransmitPacket::BASE_LEN);
64
65        let stream_item: mproto::LazyBuf<StreamItem<T>, _> = mproto::LazyBuf::new(packet);
66        // TODO LazyBuf::try_map
67        let payload = stream_item.map(|s| s.payload().unwrap());
68
69        Ok(payload)
70    }
71
72    pub fn try_next(&mut self) -> mproto::DecodeResult<Option<T>> {
73        use mproto::BaseLen;
74
75        let Some(packet) = self.receive_stream.try_next_packet() else {
76            return Ok(None);
77        };
78        packet.advance(modrpc::TransmitPacket::BASE_LEN);
79
80        let stream_item: StreamItemLazy<T> = mproto::decode_value(&packet)?;
81        let payload = stream_item.payload().and_then(|i| T::lazy_to_owned(i))?;
82
83        Ok(Some(payload))
84    }
85
86    pub fn try_next_lazy(&mut self)
87        -> mproto::DecodeResult<Option<mproto::LazyBuf<T, modrpc::Packet>>>
88    {
89        use mproto::BaseLen;
90
91        let Some(packet) = self.receive_stream.try_next_packet() else {
92            return Ok(None);
93        };
94        packet.advance(modrpc::TransmitPacket::BASE_LEN);
95
96        let stream_item: mproto::LazyBuf<StreamItem<T>, _> = mproto::LazyBuf::new(packet);
97        // TODO LazyBuf::try_map
98        let payload = stream_item.map(|s| s.payload().unwrap());
99
100        Ok(Some(payload))
101    }
102}
103
104struct Subscriptions {
105    stream_states: RefCell<Vec<Rc<StreamState>>>,
106}
107
108pub struct StreamReceiverBuilder<T> {
109    stubs: crate::StreamReceiverStubs<T>,
110    subscriptions: Rc<Subscriptions>,
111}
112
113impl<T: mproto::Owned> StreamReceiverBuilder<T> {
114    pub fn new(
115        _name: &'static str,
116        _hooks: crate::StreamReceiverHooks<T>,
117        stubs: crate::StreamReceiverStubs<T>,
118        _config: &StreamReceiverConfig,
119        _init: StreamInitState,
120    ) -> Self {
121        Self {
122            stubs,
123            subscriptions: Rc::new(Subscriptions {
124                stream_states: RefCell::new(Vec::new()),
125            }),
126        }
127    }
128
129    pub fn create_handle(
130        &self,
131        _setup: &RoleSetup,
132    ) -> crate::StreamReceiver<T> {
133        crate::StreamReceiver {
134            subscriptions: self.subscriptions.clone(),
135            _phantom: core::marker::PhantomData,
136        }
137    }
138
139    pub fn build(
140        self,
141        setup: &RoleSetup,
142    ) {
143        use mproto::BaseLen;
144
145        let subscriptions = self.subscriptions;
146        self.stubs.item.inline_untyped(setup, move |_source, packet| {
147            let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
148            let Ok(stream_item) =
149                mproto::decode_value::<StreamItemLazy<T>>(stream_item_bytes)
150            else {
151                return;
152            };
153            let Ok(seq) = stream_item.seq() else {
154                return;
155            };
156
157            for stream_state in &mut *subscriptions.stream_states.borrow_mut() {
158                let _stream_is_done = stream_state.handle_item(seq, false, packet.clone());
159            }
160        })
161        .subscribe();
162    }
163}
164
#[cfg(test)]
mod test {
    use modrpc_executor::ModrpcExecutor;
    use crate::{
        StreamInitState,
        StreamSenderBuilder,
        StreamSenderConfig,
        StreamSenderRole,
        StreamReceiverConfig,
        StreamReceiverRole,
    };
    use super::*;

    /// Runs a sender and a receiver role over a local transport and checks
    /// item delivery, the empty-queue case, and unregistration on drop.
    #[test]
    fn test_stream_receiver() {
        let mut ex = modrpc_executor::FuturesExecutor::new();
        let (rt, _rt_shutdown) = modrpc::RuntimeHandle::single_threaded(&mut ex);

        ex.run_until(async move {
            let transport = rt
                .add_transport(modrpc::LocalTransport {
                    buffer_size: 256,
                    buffer_pool_batches: 16,
                    buffer_pool_batch_size: 16,
                })
                .await;

            // Start the sender role and capture its handle.
            let mut sender_slot = None;
            let _ = rt
                .start_role::<StreamSenderRole<String>>(modrpc::RoleConfig {
                    plane_id: 0,
                    endpoint_addr: modrpc::EndpointAddr { endpoint: 0 },
                    transport: transport.clone(),
                    topic_channels: modrpc::TopicChannels::SingleChannel { channel_id: 0 },
                    config: StreamSenderConfig { },
                    init: StreamInitState { },
                })
                .local(|cx| {
                    let builder = StreamSenderBuilder::new(
                        "stream_sender",
                        cx.hooks.clone(),
                        cx.stubs,
                        cx.config,
                        cx.init.clone(),
                    );
                    sender_slot = Some(builder.create_handle(cx.setup));
                    builder.build(cx.setup);
                });

            // Start the receiver role on the same transport and channel.
            let mut receiver_slot = None;
            let _ = rt
                .start_role::<StreamReceiverRole<String>>(modrpc::RoleConfig {
                    plane_id: 0,
                    endpoint_addr: modrpc::EndpointAddr { endpoint: 0 },
                    transport,
                    topic_channels: modrpc::TopicChannels::SingleChannel { channel_id: 0 },
                    config: StreamReceiverConfig { },
                    init: StreamInitState { },
                })
                .local(|cx| {
                    let builder = StreamReceiverBuilder::new(
                        "stream_receiver",
                        cx.hooks.clone(),
                        cx.stubs,
                        cx.config,
                        cx.init.clone(),
                    );
                    receiver_slot = Some(builder.create_handle(cx.setup));
                    builder.build(cx.setup);
                });

            let stream_sender = sender_slot.unwrap();
            let stream_receiver = receiver_slot.unwrap();

            // Sent before any subscription exists.
            stream_sender.send("asdf").await;

            // Passing None to subscribe will make it accept the first seq it sees as the next seq.
            let mut subscription = stream_receiver.subscribe(None);

            assert!(matches!(subscription.try_next(), Ok(None)));

            stream_sender.send("foo").await;
            stream_sender.send("bar").await;
            stream_sender.send("baz").await;

            assert_eq!(subscription.next().await.unwrap(), "foo");
            assert_eq!(subscription.next().await.unwrap(), "bar");
            assert_eq!(subscription.next().await.unwrap(), "baz");

            assert!(matches!(subscription.try_next(), Ok(None)));

            // Test StreamSubscription::drop
            let registry = stream_receiver.subscriptions.clone();
            let subscription2 = stream_receiver.subscribe(None);
            assert_eq!(registry.stream_states.borrow().len(), 2);
            drop(subscription);
            assert_eq!(registry.stream_states.borrow().len(), 1);
            drop(subscription2);
            assert_eq!(registry.stream_states.borrow().len(), 0);
        });
    }
}