Skip to main content

zai_rs/model/
stream_ext.rs

//! # Streaming Extensions for Chat-like Endpoints
//!
//! This module provides typed streaming capabilities for chat completion APIs
//! that return Server-Sent Events (SSE) with `ChatStreamResponse` chunks.
//!
//! ## Features
//!
//! - **Callback-based API** - Simple async closure interface for processing
//!   chunks
//! - **Stream-based API** - Composable, testable, and reusable stream interface
//! - **Type-safe parsing** - Automatic deserialization of SSE data chunks
//! - **Error handling** - Network errors are propagated to the caller; SSE
//!   payloads that fail to deserialize are skipped
//!
//! ## Usage Patterns
//!
//! ### Callback-based Processing
//! ```rust,ignore
//! client.stream_for_each(|chunk| async move {
//!     println!("Received: {:?}", chunk);
//!     Ok(())
//! }).await?;
//! ```
//!
//! ### Stream-based Processing
//! ```rust,ignore
//! let mut stream = client.to_stream().await?;
//! while let Some(result) = stream.next().await {
//!     match result {
//!         Ok(chunk) => println!("Chunk: {:?}", chunk),
//!         Err(e) => eprintln!("Error: {}", e),
//!     }
//! }
//! ```

35use std::pin::Pin;
36
37use futures::{Stream, StreamExt, stream};
38use tracing::info;
39
40use crate::{
41    client::http::HttpClient,
42    model::{chat_stream_response::ChatStreamResponse, traits::SseStreamable},
43};
44
45/// Streaming extension trait for chat-like endpoints.
46///
47/// This trait provides two complementary APIs for processing streaming
48/// responses:
49/// 1. **Callback-based** - Simple async closure interface
50/// 2. **Stream-based** - Composable stream interface for advanced usage
51///
52/// Both APIs handle SSE protocol parsing, JSON deserialization, and error
53/// propagation.
54pub trait StreamChatLikeExt: SseStreamable + HttpClient {
55    /// Processes streaming responses using an async callback function.
56    ///
57    /// This method provides a simple interface for handling streaming chat
58    /// responses. Each successfully parsed chunk is passed to the provided
59    /// callback function.
60    ///
61    /// ## Arguments
62    ///
63    /// * `on_chunk` - Async callback function that processes each
64    ///   `ChatStreamResponse` chunk
65    ///
66    /// ## Returns
67    ///
68    /// Result indicating success or failure of the streaming operation
69    ///
70    /// ## Example
71    ///
72    /// ```rust,ignore
73    /// client.stream_for_each(|chunk| async move {
74    ///     if let Some(content) = &chunk.choices[0].delta.content {
75    ///         print!("{}", content);
76    ///     }
77    ///     Ok(())
78    /// }).await?;
79    /// ```
80    fn stream_for_each<'a, F, Fut>(
81        &'a mut self,
82        mut on_chunk: F,
83    ) -> impl core::future::Future<Output = crate::ZaiResult<()>> + 'a
84    where
85        F: FnMut(ChatStreamResponse) -> Fut + 'a,
86        Fut: core::future::Future<Output = crate::ZaiResult<()>> + 'a,
87    {
88        async move {
89            let resp = self.post().await?;
90            let mut stream = resp.bytes_stream();
91            let mut buf: Vec<u8> = Vec::new();
92
93            while let Some(next) = stream.next().await {
94                let bytes = match next {
95                    Ok(b) => b,
96                    Err(e) => {
97                        return Err(crate::client::error::ZaiError::NetworkError(
98                            std::sync::Arc::new(e),
99                        ));
100                    },
101                };
102                let lines = crate::model::sse_parser::extract_sse_data_lines(&mut buf, &bytes);
103                for rest in lines {
104                    info!("SSE data: {}", String::from_utf8_lossy(&rest));
105                    if rest == b"[DONE]" {
106                        return Ok(());
107                    }
108                    if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(&rest) {
109                        on_chunk(chunk).await?;
110                    }
111                }
112            }
113            Ok(())
114        }
115    }
116
117    /// Converts the streaming response into a composable Stream.
118    ///
119    /// This method returns a `Stream` that yields `ChatStreamResponse` chunks,
120    /// enabling advanced stream processing operations like filtering, mapping,
121    /// and combination with other streams.
122    ///
123    /// ## Returns
124    ///
125    /// A future that resolves to a `Stream` of `Result<ChatStreamResponse>`
126    /// items
127    ///
128    /// ## Example
129    ///
130    /// ```rust,ignore
131    /// let stream = client.to_stream().await?;
132    /// let collected: Vec<_> = stream
133    ///     .filter_map(|result| result.ok())
134    ///     .collect()
135    ///     .await;
136    /// ```
137    fn to_stream<'a>(
138        &'a mut self,
139    ) -> impl core::future::Future<
140        Output = crate::ZaiResult<
141            Pin<Box<dyn Stream<Item = crate::ZaiResult<ChatStreamResponse>> + Send + 'static>>,
142        >,
143    > + 'a {
144        async move {
145            let resp = self.post().await?;
146            let byte_stream = resp.bytes_stream();
147
148            let s = byte_stream;
149
150            let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
151                loop {
152                    // Need more bytes first to populate buffer
153                    match s.next().await {
154                        Some(Ok(bytes)) => {
155                            let lines =
156                                crate::model::sse_parser::extract_sse_data_lines(&mut buf, &bytes);
157                            for rest in lines {
158                                info!("SSE data: {}", String::from_utf8_lossy(&rest));
159                                if rest == b"[DONE]" {
160                                    return None; // end stream gracefully
161                                }
162                                if let Ok(item) =
163                                    serde_json::from_slice::<ChatStreamResponse>(&rest)
164                                {
165                                    return Some((Ok(item), (s, buf)));
166                                }
167                                // skip invalid json line, continue processing
168                                // remaining lines
169                            }
170                            // All lines processed but no valid
171                            // ChatStreamResponse yielded,
172                            // loop back to get more bytes
173                        },
174                        Some(Err(e)) => {
175                            return Some((
176                                Err(crate::client::error::ZaiError::NetworkError(
177                                    std::sync::Arc::new(e),
178                                )),
179                                (s, buf),
180                            ));
181                        },
182                        None => return None,
183                    }
184                }
185            })
186            .boxed();
187
188            Ok(out)
189        }
190    }
191}