1use bytes::Bytes;
2use futures::Stream;
3use parking_lot::Mutex;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11
12use crate::channel::message::MessageChannel;
13use crate::codec::{BincodeCodec, Codec};
14use crate::error::{Result, RpcError};
15use crate::message::Message;
16use crate::message::types::{MessageId, MessageType};
17
/// Identifier that ties stream chunks and the end marker to one logical stream.
pub type StreamId = u64;

/// Process-wide counter backing [`next_stream_id`]. Starts at 1, so the first id
/// handed out is 1.
static STREAM_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Allocates a new stream identifier, distinct from every id previously returned in
/// this process (until the 64-bit counter wraps).
///
/// Only uniqueness matters here, and the atomic read-modify-write alone provides it,
/// so `Relaxed` ordering is sufficient.
pub fn next_stream_id() -> StreamId {
    STREAM_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
25
26pub struct StreamSender<T: MessageChannel, C: Codec = BincodeCodec> {
27 stream_id: StreamId,
28 sequence: AtomicU64,
29 transport: Arc<T>,
30 codec: C,
31 ended: std::sync::atomic::AtomicBool,
32}
33
34impl<T: MessageChannel> StreamSender<T, BincodeCodec> {
35 pub fn new(stream_id: StreamId, transport: Arc<T>) -> Self {
36 Self {
37 stream_id,
38 sequence: AtomicU64::new(0),
39 transport,
40 codec: BincodeCodec,
41 ended: std::sync::atomic::AtomicBool::new(false),
42 }
43 }
44}
45
46impl<T: MessageChannel, C: Codec> StreamSender<T, C> {
47 pub fn with_codec(stream_id: StreamId, transport: Arc<T>, codec: C) -> Self {
48 Self {
49 stream_id,
50 sequence: AtomicU64::new(0),
51 transport,
52 codec,
53 ended: std::sync::atomic::AtomicBool::new(false),
54 }
55 }
56
57 pub async fn send<D: Serialize>(&self, data: D) -> Result<()> {
58 if self.ended.load(Ordering::Acquire) {
59 return Err(RpcError::StreamError("Stream already ended".to_string()));
60 }
61
62 let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
63 let payload = self.codec.encode(&data)?;
64 let chunk = Message::new(
65 MessageId::new(),
66 MessageType::StreamChunk,
67 "",
68 Bytes::from(payload),
69 crate::message::metadata::MessageMetadata::new().with_stream(self.stream_id, seq),
70 );
71 self.transport
72 .send(&chunk)
73 .await
74 .map_err(RpcError::Transport)
75 }
76
77 pub async fn end(&self) -> Result<()> {
78 if self
79 .ended
80 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
81 .is_err()
82 {
83 return Ok(());
84 }
85
86 let end_msg: Message = Message::stream_end(self.stream_id);
87 self.transport
88 .send(&end_msg)
89 .await
90 .map_err(RpcError::Transport)
91 }
92
93 pub fn stream_id(&self) -> StreamId {
94 self.stream_id
95 }
96
97 pub fn is_ended(&self) -> bool {
98 self.ended.load(Ordering::Acquire)
99 }
100}
101
102pub struct StreamReceiver<D, C: Codec = BincodeCodec> {
104 stream_id: StreamId,
105 rx: mpsc::UnboundedReceiver<Result<Bytes>>,
106 ended: bool,
107 codec: C,
108 _phantom: std::marker::PhantomData<D>,
109}
110
111impl<D, C> StreamReceiver<D, C>
112where
113 D: for<'de> Deserialize<'de>,
114 C: Codec,
115{
116 pub(crate) fn new(
117 stream_id: StreamId,
118 rx: mpsc::UnboundedReceiver<Result<Bytes>>,
119 codec: C,
120 ) -> Self {
121 Self {
122 stream_id,
123 rx,
124 ended: false,
125 codec,
126 _phantom: std::marker::PhantomData,
127 }
128 }
129
130 pub fn stream_id(&self) -> StreamId {
131 self.stream_id
132 }
133
134 pub fn is_ended(&self) -> bool {
135 self.ended
136 }
137
138 pub async fn recv(&mut self) -> Option<Result<D>> {
139 if self.ended {
140 return None;
141 }
142
143 match self.rx.recv().await {
144 Some(Ok(data)) => Some(self.codec.decode(&data)),
145 Some(Err(e)) => {
146 self.ended = true;
147 Some(Err(e))
148 }
149 None => {
150 self.ended = true;
151 None
152 }
153 }
154 }
155
156 pub async fn collect(mut self) -> Result<Vec<D>> {
157 let mut items = Vec::new();
158 while let Some(result) = self.recv().await {
159 items.push(result?);
160 }
161 Ok(items)
162 }
163
164 pub fn cancel(&mut self) {
165 self.ended = true;
166 self.rx.close();
167 }
168}
169
170impl<D, C> Stream for StreamReceiver<D, C>
171where
172 D: for<'de> Deserialize<'de> + Unpin,
173 C: Codec + Unpin,
174{
175 type Item = Result<D>;
176
177 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178 if self.ended {
179 return Poll::Ready(None);
180 }
181
182 match self.rx.poll_recv(cx) {
183 Poll::Ready(Some(Ok(data))) => {
184 let result = self.codec.decode(&data);
185 Poll::Ready(Some(result))
186 }
187 Poll::Ready(Some(Err(e))) => {
188 self.ended = true;
189 Poll::Ready(Some(Err(e)))
190 }
191 Poll::Ready(None) => {
192 self.ended = true;
193 Poll::Ready(None)
194 }
195 Poll::Pending => Poll::Pending,
196 }
197 }
198}
199
200pub struct StreamManager<C: Codec = BincodeCodec> {
201 streams: Arc<Mutex<HashMap<StreamId, mpsc::UnboundedSender<Result<Bytes>>>>>,
202 codec: C,
203}
204
205impl StreamManager<BincodeCodec> {
206 pub fn new() -> Self {
207 Self {
208 streams: Arc::new(Mutex::new(HashMap::new())),
209 codec: BincodeCodec,
210 }
211 }
212}
213
214impl<C: Codec + Clone> StreamManager<C> {
215 pub fn with_codec(codec: C) -> Self {
216 Self {
217 streams: Arc::new(Mutex::new(HashMap::new())),
218 codec,
219 }
220 }
221
222 pub fn create_receiver<D>(&self, stream_id: StreamId) -> StreamReceiver<D, C>
223 where
224 D: for<'de> Deserialize<'de>,
225 {
226 let (tx, rx) = mpsc::unbounded_channel();
227 self.streams.lock().insert(stream_id, tx);
228 StreamReceiver::new(stream_id, rx, self.codec.clone())
229 }
230
231 pub fn handle_message(&self, message: &Message<C>) -> bool {
232 let stream_id = match message.metadata.stream_id {
233 Some(id) => id,
234 None => return false,
235 };
236
237 let streams = self.streams.lock();
238 let sender = match streams.get(&stream_id) {
239 Some(tx) => tx,
240 None => return false,
241 };
242
243 match message.msg_type {
244 MessageType::StreamChunk => {
245 let _ = sender.send(Ok(message.payload.clone()));
246 true
247 }
248 MessageType::StreamEnd => {
249 drop(streams);
250 self.remove_stream(stream_id);
251 true
252 }
253 _ => false,
254 }
255 }
256
257 pub fn send_error(&self, stream_id: StreamId, error_msg: String) {
259 let streams = self.streams.lock();
260 if let Some(sender) = streams.get(&stream_id) {
261 let _ = sender.send(Err(RpcError::ServerError(error_msg)));
262 }
263 drop(streams);
264 self.remove_stream(stream_id);
265 }
266
267 pub fn remove_stream(&self, stream_id: StreamId) {
268 self.streams.lock().remove(&stream_id);
269 }
270
271 pub fn active_stream_count(&self) -> usize {
272 self.streams.lock().len()
273 }
274}
275
276impl Default for StreamManager<BincodeCodec> {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Three chunks sent through a connected channel pair reach the receiver in order.
    #[tokio::test]
    async fn test_stream_sender_receiver() {
        let (local, remote) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let outbound = Arc::new(MessageChannelAdapter::new(local));
        let inbound = MessageChannelAdapter::new(remote);

        let stream_id = next_stream_id();
        let sender = StreamSender::new(stream_id, outbound);

        let manager = StreamManager::new();
        let receiver: StreamReceiver<i32> = manager.create_receiver(stream_id);

        // Feed messages from the remote end into the manager. Stop when one is not
        // claimed or the end marker arrives, then hand the receiver back.
        let pump = tokio::spawn(async move {
            loop {
                let msg = inbound.recv().await.unwrap();
                let handled = manager.handle_message(&msg);
                if !handled || msg.msg_type == MessageType::StreamEnd {
                    break;
                }
            }
            receiver
        });

        for value in 1..=3i32 {
            sender.send(value).await.unwrap();
        }
        sender.end().await.unwrap();

        let mut receiver = pump.await.unwrap();

        let mut items: Vec<i32> = Vec::with_capacity(3);
        for _ in 0..3 {
            items.push(receiver.recv().await.unwrap().unwrap());
        }
        assert_eq!(items, vec![1, 2, 3]);
    }

    /// Consecutive allocations never repeat an id.
    #[test]
    fn test_stream_id_generation() {
        let (first, second) = (next_stream_id(), next_stream_id());
        assert_ne!(first, second);
    }
}