Skip to main content

zai_rs/model/
stream_ext.rs

1//! # Streaming Extensions for Chat-like Endpoints
2//!
3//! This module provides typed streaming capabilities for chat completion APIs
4//! that return Server-Sent Events (SSE) with `ChatStreamResponse` chunks.
5//!
6//! ## Features
7//!
8//! - **Callback-based API** - Simple async closure interface for processing
9//!   chunks
10//! - **Stream-based API** - Composable, testable, and reusable stream interface
11//! - **Type-safe parsing** - Automatic deserialization of SSE data chunks
12//! - **Error handling** - Comprehensive error propagation and handling
13//!
14//! ## Usage Patterns
15//!
16//! ### Callback-based Processing
17//! ```rust,ignore
18//! client.stream_for_each(|chunk| async move {
19//!     println!("Received: {:?}", chunk);
20//!     Ok(())
21//! }).await?;
22//! ```
23//!
24//! ### Stream-based Processing
25//! ```rust,ignore
26//! let mut stream = client.to_stream().await?;
27//! while let Some(result) = stream.next().await {
28//!     match result {
29//!         Ok(chunk) => println!("Chunk: {:?}", chunk),
30//!         Err(e) => eprintln!("Error: {}", e),
31//!     }
32//! }
33//! ```
34
35use std::pin::Pin;
36
37use futures::{Stream, StreamExt, stream};
38use log::info;
39
40use crate::{
41    client::http::HttpClient,
42    model::{chat_stream_response::ChatStreamResponse, traits::SseStreamable},
43};
44
45/// Streaming extension trait for chat-like endpoints.
46///
47/// This trait provides two complementary APIs for processing streaming
48/// responses:
49/// 1. **Callback-based** - Simple async closure interface
50/// 2. **Stream-based** - Composable stream interface for advanced usage
51///
52/// Both APIs handle SSE protocol parsing, JSON deserialization, and error
53/// propagation.
54pub trait StreamChatLikeExt: SseStreamable + HttpClient {
55    /// Processes streaming responses using an async callback function.
56    ///
57    /// This method provides a simple interface for handling streaming chat
58    /// responses. Each successfully parsed chunk is passed to the provided
59    /// callback function.
60    ///
61    /// ## Arguments
62    ///
63    /// * `on_chunk` - Async callback function that processes each
64    ///   `ChatStreamResponse` chunk
65    ///
66    /// ## Returns
67    ///
68    /// Result indicating success or failure of the streaming operation
69    ///
70    /// ## Example
71    ///
72    /// ```rust,ignore
73    /// client.stream_for_each(|chunk| async move {
74    ///     if let Some(content) = &chunk.choices[0].delta.content {
75    ///         print!("{}", content);
76    ///     }
77    ///     Ok(())
78    /// }).await?;
79    /// ```
80    fn stream_for_each<'a, F, Fut>(
81        &'a mut self,
82        mut on_chunk: F,
83    ) -> impl core::future::Future<Output = crate::ZaiResult<()>> + 'a
84    where
85        F: FnMut(ChatStreamResponse) -> Fut + 'a,
86        Fut: core::future::Future<Output = crate::ZaiResult<()>> + 'a,
87    {
88        async move {
89            let resp = self.post().await?;
90            let mut stream = resp.bytes_stream();
91            let mut buf: Vec<u8> = Vec::new();
92
93            while let Some(next) = stream.next().await {
94                let bytes = match next {
95                    Ok(b) => b,
96                    Err(e) => {
97                        return Err(crate::client::error::ZaiError::Unknown {
98                            code: 0,
99                            message: format!("Stream error: {}", e),
100                        });
101                    },
102                };
103                buf.extend_from_slice(&bytes);
104                while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
105                    let line_vec: Vec<u8> = buf.drain(..=pos).collect();
106                    let mut line = &line_vec[..];
107                    if line.ends_with(b"\n") {
108                        line = &line[..line.len() - 1];
109                    }
110                    if line.ends_with(b"\r") {
111                        line = &line[..line.len() - 1];
112                    }
113                    if line.is_empty() {
114                        continue;
115                    }
116                    const PREFIX: &[u8] = b"data: ";
117                    if line.starts_with(PREFIX) {
118                        let rest = &line[PREFIX.len()..];
119                        info!("SSE data: {}", String::from_utf8_lossy(rest));
120                        if rest == b"[DONE]" {
121                            return Ok(());
122                        }
123                        if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(rest) {
124                            on_chunk(chunk).await?;
125                        }
126                    }
127                }
128            }
129            Ok(())
130        }
131    }
132
133    /// Converts the streaming response into a composable Stream.
134    ///
135    /// This method returns a `Stream` that yields `ChatStreamResponse` chunks,
136    /// enabling advanced stream processing operations like filtering, mapping,
137    /// and combination with other streams.
138    ///
139    /// ## Returns
140    ///
141    /// A future that resolves to a `Stream` of `Result<ChatStreamResponse>`
142    /// items
143    ///
144    /// ## Example
145    ///
146    /// ```rust,ignore
147    /// let stream = client.to_stream().await?;
148    /// let collected: Vec<_> = stream
149    ///     .filter_map(|result| result.ok())
150    ///     .collect()
151    ///     .await;
152    /// ```
153    fn to_stream<'a>(
154        &'a mut self,
155    ) -> impl core::future::Future<
156        Output = crate::ZaiResult<
157            Pin<Box<dyn Stream<Item = crate::ZaiResult<ChatStreamResponse>> + Send + 'static>>,
158        >,
159    > + 'a {
160        async move {
161            let resp = self.post().await?;
162            let byte_stream = resp.bytes_stream();
163
164            // State: (byte_stream, buffer)
165            let s = byte_stream;
166
167            let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
168                loop {
169                    // Process all complete lines currently in buffer
170                    while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
171                        let line_vec: Vec<u8> = buf.drain(..=pos).collect();
172                        let mut line = &line_vec[..];
173                        if line.ends_with(b"\n") {
174                            line = &line[..line.len() - 1];
175                        }
176                        if line.ends_with(b"\r") {
177                            line = &line[..line.len() - 1];
178                        }
179                        if line.is_empty() {
180                            continue;
181                        }
182                        const PREFIX: &[u8] = b"data: ";
183                        if line.starts_with(PREFIX) {
184                            let rest = &line[PREFIX.len()..];
185                            info!("SSE data: {}", String::from_utf8_lossy(rest));
186                            if rest == b"[DONE]" {
187                                return None; // end stream gracefully
188                            }
189                            match serde_json::from_slice::<ChatStreamResponse>(rest) {
190                                Ok(item) => return Some((Ok(item), (s, buf))),
191                                Err(_) => { /* skip invalid json line */ },
192                            }
193                        }
194                    }
195                    // Need more bytes
196                    match s.next().await {
197                        Some(Ok(bytes)) => buf.extend_from_slice(&bytes),
198                        Some(Err(e)) => {
199                            return Some((
200                                Err(crate::client::error::ZaiError::Unknown {
201                                    code: 0,
202                                    message: format!("Stream error: {}", e),
203                                }),
204                                (s, buf),
205                            ));
206                        },
207                        None => return None,
208                    }
209                }
210            })
211            .boxed();
212
213            Ok(out)
214        }
215    }
216}