zai_rs/model/
stream_ext.rs

1//! # Streaming Extensions for Chat-like Endpoints
2//!
3//! This module provides typed streaming capabilities for chat completion APIs
4//! that return Server-Sent Events (SSE) with `ChatStreamResponse` chunks.
5//!
6//! ## Features
7//!
8//! - **Callback-based API** - Simple async closure interface for processing chunks
9//! - **Stream-based API** - Composable, testable, and reusable stream interface
10//! - **Type-safe parsing** - Automatic deserialization of SSE data chunks
//! - **Error handling** - Transport and callback errors are propagated; data chunks that
//!   fail to deserialize are skipped
12//!
13//! ## Usage Patterns
14//!
15//! ### Callback-based Processing
16//! ```rust,ignore
17//! client.stream_for_each(|chunk| async move {
18//!     println!("Received: {:?}", chunk);
19//!     Ok(())
20//! }).await?;
21//! ```
22//!
23//! ### Stream-based Processing
24//! ```rust,ignore
25//! let mut stream = client.to_stream().await?;
26//! while let Some(result) = stream.next().await {
27//!     match result {
28//!         Ok(chunk) => println!("Chunk: {:?}", chunk),
29//!         Err(e) => eprintln!("Error: {}", e),
30//!     }
31//! }
32//! ```
33
34use crate::client::http::HttpClient;
35use crate::model::chat_stream_response::ChatStreamResponse;
36use crate::model::traits::SseStreamable;
37use futures::{Stream, StreamExt, stream};
38use log::info;
39use std::pin::Pin;
40
41/// Streaming extension trait for chat-like endpoints.
42///
43/// This trait provides two complementary APIs for processing streaming responses:
44/// 1. **Callback-based** - Simple async closure interface
45/// 2. **Stream-based** - Composable stream interface for advanced usage
46///
47/// Both APIs handle SSE protocol parsing, JSON deserialization, and error propagation.
48pub trait StreamChatLikeExt: SseStreamable + HttpClient {
49    /// Processes streaming responses using an async callback function.
50    ///
51    /// This method provides a simple interface for handling streaming chat responses.
52    /// Each successfully parsed chunk is passed to the provided callback function.
53    ///
54    /// ## Arguments
55    ///
56    /// * `on_chunk` - Async callback function that processes each `ChatStreamResponse` chunk
57    ///
58    /// ## Returns
59    ///
60    /// Result indicating success or failure of the streaming operation
61    ///
62    /// ## Example
63    ///
64    /// ```rust,ignore
65    /// client.stream_for_each(|chunk| async move {
66    ///     if let Some(content) = &chunk.choices[0].delta.content {
67    ///         print!("{}", content);
68    ///     }
69    ///     Ok(())
70    /// }).await?;
71    /// ```
72    fn stream_for_each<'a, F, Fut>(
73        &'a mut self,
74        mut on_chunk: F,
75    ) -> impl core::future::Future<Output = anyhow::Result<()>> + 'a
76    where
77        F: FnMut(ChatStreamResponse) -> Fut + 'a,
78        Fut: core::future::Future<Output = anyhow::Result<()>> + 'a,
79    {
80        async move {
81            let resp = self.post().await?;
82            let mut stream = resp.bytes_stream();
83            let mut buf: Vec<u8> = Vec::new();
84
85            while let Some(next) = stream.next().await {
86                let bytes = match next {
87                    Ok(b) => b,
88                    Err(e) => return Err(anyhow::anyhow!("Stream error: {}", e)),
89                };
90                buf.extend_from_slice(&bytes);
91                while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
92                    let line_vec: Vec<u8> = buf.drain(..=pos).collect();
93                    let mut line = &line_vec[..];
94                    if line.ends_with(b"\n") {
95                        line = &line[..line.len() - 1];
96                    }
97                    if line.ends_with(b"\r") {
98                        line = &line[..line.len() - 1];
99                    }
100                    if line.is_empty() {
101                        continue;
102                    }
103                    const PREFIX: &[u8] = b"data: ";
104                    if line.starts_with(PREFIX) {
105                        let rest = &line[PREFIX.len()..];
106                        info!("SSE data: {}", String::from_utf8_lossy(rest));
107                        if rest == b"[DONE]" {
108                            return Ok(());
109                        }
110                        if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(rest) {
111                            on_chunk(chunk).await?;
112                        }
113                    }
114                }
115            }
116            Ok(())
117        }
118    }
119
120    /// Converts the streaming response into a composable Stream.
121    ///
122    /// This method returns a `Stream` that yields `ChatStreamResponse` chunks,
123    /// enabling advanced stream processing operations like filtering, mapping,
124    /// and combination with other streams.
125    ///
126    /// ## Returns
127    ///
128    /// A future that resolves to a `Stream` of `Result<ChatStreamResponse>` items
129    ///
130    /// ## Example
131    ///
132    /// ```rust,ignore
133    /// let stream = client.to_stream().await?;
134    /// let collected: Vec<_> = stream
135    ///     .filter_map(|result| result.ok())
136    ///     .collect()
137    ///     .await;
138    /// ```
139    fn to_stream<'a>(
140        &'a mut self,
141    ) -> impl core::future::Future<
142        Output = anyhow::Result<
143            Pin<Box<dyn Stream<Item = anyhow::Result<ChatStreamResponse>> + Send + 'static>>,
144        >,
145    > + 'a {
146        async move {
147            let resp = self.post().await?;
148            let byte_stream = resp.bytes_stream();
149
150            // State: (byte_stream, buffer)
151            let s = byte_stream;
152
153            let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
154                loop {
155                    // Process all complete lines currently in buffer
156                    while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
157                        let line_vec: Vec<u8> = buf.drain(..=pos).collect();
158                        let mut line = &line_vec[..];
159                        if line.ends_with(b"\n") {
160                            line = &line[..line.len() - 1];
161                        }
162                        if line.ends_with(b"\r") {
163                            line = &line[..line.len() - 1];
164                        }
165                        if line.is_empty() {
166                            continue;
167                        }
168                        const PREFIX: &[u8] = b"data: ";
169                        if line.starts_with(PREFIX) {
170                            let rest = &line[PREFIX.len()..];
171                            info!("SSE data: {}", String::from_utf8_lossy(rest));
172                            if rest == b"[DONE]" {
173                                return None; // end stream gracefully
174                            }
175                            match serde_json::from_slice::<ChatStreamResponse>(rest) {
176                                Ok(item) => return Some((Ok(item), (s, buf))),
177                                Err(_) => { /* skip invalid json line */ }
178                            }
179                        }
180                    }
181                    // Need more bytes
182                    match s.next().await {
183                        Some(Ok(bytes)) => buf.extend_from_slice(&bytes),
184                        Some(Err(e)) => {
185                            return Some((Err(anyhow::anyhow!("Stream error: {}", e)), (s, buf)));
186                        }
187                        None => return None,
188                    }
189                }
190            })
191            .boxed();
192
193            Ok(out)
194        }
195    }
196}