std_modrpc/role_impls/
stream_receiver.rs1use std::{
2 cell::RefCell,
3 rc::Rc,
4};
5use modrpc::RoleSetup;
6
7use crate::{
8 proto::{StreamInitState, StreamItem, StreamItemLazy, StreamReceiverConfig},
9 receive_stream::{ReceiveStream, StreamState},
10};
11
12#[derive(Clone)]
13pub struct StreamReceiver<T> {
14 subscriptions: Rc<Subscriptions>,
15 _phantom: core::marker::PhantomData<T>,
16}
17
18impl<T: mproto::Owned> StreamReceiver<T> {
19 pub fn subscribe(&self, next_seq: Option<u64>) -> StreamSubscription<T> {
20 let receive_stream = ReceiveStream::new(next_seq);
21 self.subscriptions.stream_states.borrow_mut().push(receive_stream.stream_state().clone());
22 StreamSubscription {
23 receive_stream,
24 subscriptions: self.subscriptions.clone(),
25 _phantom: core::marker::PhantomData,
26 }
27 }
28}
29
/// One subscriber's view of a stream, created by `StreamReceiver::subscribe`.
///
/// Items are read with `next`/`try_next` (owned) or `next_lazy`/`try_next_lazy` (borrowing
/// the packet). Dropping the subscription removes its state from the shared registry.
pub struct StreamSubscription<T> {
    // Source of packets for this subscription; its `stream_state()` is the entry registered
    // in `subscriptions`.
    receive_stream: ReceiveStream,
    // Shared registry; kept so `Drop` can unregister this subscription's state.
    subscriptions: Rc<Subscriptions>,
    _phantom: core::marker::PhantomData<T>,
}
35
36impl<T> Drop for StreamSubscription<T> {
37 fn drop(&mut self) {
38 self.subscriptions.stream_states.borrow_mut()
39 .retain(|s| Rc::as_ptr(s) != Rc::as_ptr(self.receive_stream.stream_state()));
40 }
41}
42
43impl<T: mproto::Owned> StreamSubscription<T> {
44 pub async fn next(&mut self) -> mproto::DecodeResult<T> {
45 use mproto::BaseLen;
46
47 let packet = self.receive_stream.next_packet().await;
48
49 let stream_item: StreamItemLazy<T> = mproto::decode_value(
50 &packet.as_ref()[modrpc::TransmitPacket::BASE_LEN..]
51 )?;
52 let owned_result = stream_item.payload().map(|i| T::lazy_to_owned(i))??;
53
54 Ok(owned_result)
55 }
56
57 pub async fn next_lazy(&mut self)
58 -> mproto::DecodeResult<mproto::LazyBuf<T, modrpc::Packet>>
59 {
60 use mproto::BaseLen;
61
62 let packet = self.receive_stream.next_packet().await;
63 packet.advance(modrpc::TransmitPacket::BASE_LEN);
64
65 let stream_item: mproto::LazyBuf<StreamItem<T>, _> = mproto::LazyBuf::new(packet);
66 let payload = stream_item.map(|s| s.payload().unwrap());
68
69 Ok(payload)
70 }
71
72 pub fn try_next(&mut self) -> mproto::DecodeResult<Option<T>> {
73 use mproto::BaseLen;
74
75 let Some(packet) = self.receive_stream.try_next_packet() else {
76 return Ok(None);
77 };
78 packet.advance(modrpc::TransmitPacket::BASE_LEN);
79
80 let stream_item: StreamItemLazy<T> = mproto::decode_value(&packet)?;
81 let payload = stream_item.payload().and_then(|i| T::lazy_to_owned(i))?;
82
83 Ok(Some(payload))
84 }
85
86 pub fn try_next_lazy(&mut self)
87 -> mproto::DecodeResult<Option<mproto::LazyBuf<T, modrpc::Packet>>>
88 {
89 use mproto::BaseLen;
90
91 let Some(packet) = self.receive_stream.try_next_packet() else {
92 return Ok(None);
93 };
94 packet.advance(modrpc::TransmitPacket::BASE_LEN);
95
96 let stream_item: mproto::LazyBuf<StreamItem<T>, _> = mproto::LazyBuf::new(packet);
97 let payload = stream_item.map(|s| s.payload().unwrap());
99
100 Ok(Some(payload))
101 }
102}
103
/// Registry of the state of every live `StreamSubscription` created through one builder.
///
/// `StreamReceiver::subscribe` appends entries, each subscription's `Drop` removes its own
/// entry, and the inbound handler installed by `StreamReceiverBuilder::build` forwards every
/// received packet to all entries.
struct Subscriptions {
    // One entry per live subscription, shared with that subscription's `ReceiveStream`.
    stream_states: RefCell<Vec<Rc<StreamState>>>,
}
107
/// Builder for the stream receiver role.
///
/// Hands out `StreamReceiver` handles via `create_handle` and, in `build`, installs the
/// handler that dispatches incoming stream items to all subscriptions.
pub struct StreamReceiverBuilder<T> {
    // Generated stubs; `build` attaches an inline handler to `stubs.item`.
    stubs: crate::StreamReceiverStubs<T>,
    // Shared with every handle from `create_handle` and with the handler installed by `build`.
    subscriptions: Rc<Subscriptions>,
}
112
113impl<T: mproto::Owned> StreamReceiverBuilder<T> {
114 pub fn new(
115 _name: &'static str,
116 _hooks: crate::StreamReceiverHooks<T>,
117 stubs: crate::StreamReceiverStubs<T>,
118 _config: &StreamReceiverConfig,
119 _init: StreamInitState,
120 ) -> Self {
121 Self {
122 stubs,
123 subscriptions: Rc::new(Subscriptions {
124 stream_states: RefCell::new(Vec::new()),
125 }),
126 }
127 }
128
129 pub fn create_handle(
130 &self,
131 _setup: &RoleSetup,
132 ) -> crate::StreamReceiver<T> {
133 crate::StreamReceiver {
134 subscriptions: self.subscriptions.clone(),
135 _phantom: core::marker::PhantomData,
136 }
137 }
138
139 pub fn build(
140 self,
141 setup: &RoleSetup,
142 ) {
143 use mproto::BaseLen;
144
145 let subscriptions = self.subscriptions;
146 self.stubs.item.inline_untyped(setup, move |_source, packet| {
147 let stream_item_bytes = &packet[modrpc::TransmitPacket::BASE_LEN..];
148 let Ok(stream_item) =
149 mproto::decode_value::<StreamItemLazy<T>>(stream_item_bytes)
150 else {
151 return;
152 };
153 let Ok(seq) = stream_item.seq() else {
154 return;
155 };
156
157 for stream_state in &mut *subscriptions.stream_states.borrow_mut() {
158 let _stream_is_done = stream_state.handle_item(seq, false, packet.clone());
159 }
160 })
161 .subscribe();
162 }
163}
164
#[cfg(test)]
mod test {
    use modrpc_executor::ModrpcExecutor;
    use crate::{
        StreamInitState,
        StreamReceiverConfig,
        StreamReceiverRole,
        StreamSenderBuilder,
        StreamSenderConfig,
        StreamSenderRole,
    };
    use super::*;

    /// End-to-end check over a local transport.
    ///
    /// An item sent before subscribing is not delivered. Items sent afterwards arrive in order.
    /// Dropping the subscription removes it from the receiver's registry.
    #[test]
    fn test_stream_receiver() {
        let mut executor = modrpc_executor::FuturesExecutor::new();
        let (runtime, _shutdown_guard) = modrpc::RuntimeHandle::single_threaded(&mut executor);

        executor.run_until(async move {
            let transport = runtime
                .add_transport(modrpc::LocalTransport {
                    buffer_size: 256,
                    buffer_pool_batches: 16,
                    buffer_pool_batch_size: 16,
                })
                .await;

            // Sender role: capture its handle from inside the setup closure.
            let mut sender_handle = None;
            let _ = runtime
                .start_role::<StreamSenderRole<String>>(modrpc::RoleConfig {
                    plane_id: 0,
                    endpoint_addr: modrpc::EndpointAddr { endpoint: 0 },
                    transport: transport.clone(),
                    topic_channels: modrpc::TopicChannels::SingleChannel { channel_id: 0 },
                    config: StreamSenderConfig {},
                    init: StreamInitState {},
                })
                .local(|cx| {
                    let builder = StreamSenderBuilder::new(
                        "stream_sender",
                        cx.hooks.clone(),
                        cx.stubs,
                        cx.config,
                        cx.init.clone(),
                    );
                    sender_handle = Some(builder.create_handle(cx.setup));
                    builder.build(cx.setup);
                });

            // Receiver role on the same transport and channel.
            let mut receiver_handle = None;
            let _ = runtime
                .start_role::<StreamReceiverRole<String>>(modrpc::RoleConfig {
                    plane_id: 0,
                    endpoint_addr: modrpc::EndpointAddr { endpoint: 0 },
                    transport,
                    topic_channels: modrpc::TopicChannels::SingleChannel { channel_id: 0 },
                    config: StreamReceiverConfig {},
                    init: StreamInitState {},
                })
                .local(|cx| {
                    let builder = StreamReceiverBuilder::new(
                        "stream_receiver",
                        cx.hooks.clone(),
                        cx.stubs,
                        cx.config,
                        cx.init.clone(),
                    );
                    receiver_handle = Some(builder.create_handle(cx.setup));
                    builder.build(cx.setup);
                });

            let mut sender = sender_handle.unwrap();
            let receiver = receiver_handle.unwrap();

            // Sent before any subscription exists, so nobody should see it.
            sender.send("asdf").await;

            let mut subscription = receiver.subscribe(None);
            assert!(matches!(subscription.try_next(), Ok(None)));

            for item in ["foo", "bar", "baz"] {
                sender.send(item).await;
            }
            for expected in ["foo", "bar", "baz"] {
                assert_eq!(subscription.next().await.unwrap(), expected);
            }

            assert!(matches!(subscription.try_next(), Ok(None)));

            let registry = receiver.subscriptions.clone();
            assert_eq!(registry.stream_states.borrow().len(), 1);
            drop(subscription);
            assert_eq!(registry.stream_states.borrow().len(), 0);
        });
    }
}