zai_rs/model/stream_ext.rs
1//! # Streaming Extensions for Chat-like Endpoints
2//!
3//! This module provides typed streaming capabilities for chat completion APIs
4//! that return Server-Sent Events (SSE) with `ChatStreamResponse` chunks.
5//!
6//! ## Features
7//!
8//! - **Callback-based API** - Simple async closure interface for processing
9//! chunks
10//! - **Stream-based API** - Composable, testable, and reusable stream interface
11//! - **Type-safe parsing** - Automatic deserialization of SSE data chunks
//! - **Error handling** - Request, transport, and callback errors are
//!   propagated to the caller; data lines that fail to parse are skipped
13//!
14//! ## Usage Patterns
15//!
16//! ### Callback-based Processing
17//! ```rust,ignore
18//! client.stream_for_each(|chunk| async move {
19//! println!("Received: {:?}", chunk);
20//! Ok(())
21//! }).await?;
22//! ```
23//!
24//! ### Stream-based Processing
25//! ```rust,ignore
26//! let mut stream = client.to_stream().await?;
27//! while let Some(result) = stream.next().await {
28//! match result {
29//! Ok(chunk) => println!("Chunk: {:?}", chunk),
30//! Err(e) => eprintln!("Error: {}", e),
31//! }
32//! }
33//! ```
34
35use std::pin::Pin;
36
37use futures::{Stream, StreamExt, stream};
38use log::info;
39
40use crate::{
41 client::http::HttpClient,
42 model::{chat_stream_response::ChatStreamResponse, traits::SseStreamable},
43};
44
45/// Streaming extension trait for chat-like endpoints.
46///
47/// This trait provides two complementary APIs for processing streaming
48/// responses:
49/// 1. **Callback-based** - Simple async closure interface
50/// 2. **Stream-based** - Composable stream interface for advanced usage
51///
52/// Both APIs handle SSE protocol parsing, JSON deserialization, and error
53/// propagation.
54pub trait StreamChatLikeExt: SseStreamable + HttpClient {
55 /// Processes streaming responses using an async callback function.
56 ///
57 /// This method provides a simple interface for handling streaming chat
58 /// responses. Each successfully parsed chunk is passed to the provided
59 /// callback function.
60 ///
61 /// ## Arguments
62 ///
63 /// * `on_chunk` - Async callback function that processes each
64 /// `ChatStreamResponse` chunk
65 ///
66 /// ## Returns
67 ///
68 /// Result indicating success or failure of the streaming operation
69 ///
70 /// ## Example
71 ///
72 /// ```rust,ignore
73 /// client.stream_for_each(|chunk| async move {
74 /// if let Some(content) = &chunk.choices[0].delta.content {
75 /// print!("{}", content);
76 /// }
77 /// Ok(())
78 /// }).await?;
79 /// ```
80 fn stream_for_each<'a, F, Fut>(
81 &'a mut self,
82 mut on_chunk: F,
83 ) -> impl core::future::Future<Output = crate::ZaiResult<()>> + 'a
84 where
85 F: FnMut(ChatStreamResponse) -> Fut + 'a,
86 Fut: core::future::Future<Output = crate::ZaiResult<()>> + 'a,
87 {
88 async move {
89 let resp = self.post().await?;
90 let mut stream = resp.bytes_stream();
91 let mut buf: Vec<u8> = Vec::new();
92
93 while let Some(next) = stream.next().await {
94 let bytes = match next {
95 Ok(b) => b,
96 Err(e) => {
97 return Err(crate::client::error::ZaiError::Unknown {
98 code: 0,
99 message: format!("Stream error: {}", e),
100 });
101 },
102 };
103 buf.extend_from_slice(&bytes);
104 while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
105 let line_vec: Vec<u8> = buf.drain(..=pos).collect();
106 let mut line = &line_vec[..];
107 if line.ends_with(b"\n") {
108 line = &line[..line.len() - 1];
109 }
110 if line.ends_with(b"\r") {
111 line = &line[..line.len() - 1];
112 }
113 if line.is_empty() {
114 continue;
115 }
116 const PREFIX: &[u8] = b"data: ";
117 if line.starts_with(PREFIX) {
118 let rest = &line[PREFIX.len()..];
119 info!("SSE data: {}", String::from_utf8_lossy(rest));
120 if rest == b"[DONE]" {
121 return Ok(());
122 }
123 if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(rest) {
124 on_chunk(chunk).await?;
125 }
126 }
127 }
128 }
129 Ok(())
130 }
131 }
132
133 /// Converts the streaming response into a composable Stream.
134 ///
135 /// This method returns a `Stream` that yields `ChatStreamResponse` chunks,
136 /// enabling advanced stream processing operations like filtering, mapping,
137 /// and combination with other streams.
138 ///
139 /// ## Returns
140 ///
141 /// A future that resolves to a `Stream` of `Result<ChatStreamResponse>`
142 /// items
143 ///
144 /// ## Example
145 ///
146 /// ```rust,ignore
147 /// let stream = client.to_stream().await?;
148 /// let collected: Vec<_> = stream
149 /// .filter_map(|result| result.ok())
150 /// .collect()
151 /// .await;
152 /// ```
153 fn to_stream<'a>(
154 &'a mut self,
155 ) -> impl core::future::Future<
156 Output = crate::ZaiResult<
157 Pin<Box<dyn Stream<Item = crate::ZaiResult<ChatStreamResponse>> + Send + 'static>>,
158 >,
159 > + 'a {
160 async move {
161 let resp = self.post().await?;
162 let byte_stream = resp.bytes_stream();
163
164 // State: (byte_stream, buffer)
165 let s = byte_stream;
166
167 let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
168 loop {
169 // Process all complete lines currently in buffer
170 while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
171 let line_vec: Vec<u8> = buf.drain(..=pos).collect();
172 let mut line = &line_vec[..];
173 if line.ends_with(b"\n") {
174 line = &line[..line.len() - 1];
175 }
176 if line.ends_with(b"\r") {
177 line = &line[..line.len() - 1];
178 }
179 if line.is_empty() {
180 continue;
181 }
182 const PREFIX: &[u8] = b"data: ";
183 if line.starts_with(PREFIX) {
184 let rest = &line[PREFIX.len()..];
185 info!("SSE data: {}", String::from_utf8_lossy(rest));
186 if rest == b"[DONE]" {
187 return None; // end stream gracefully
188 }
189 match serde_json::from_slice::<ChatStreamResponse>(rest) {
190 Ok(item) => return Some((Ok(item), (s, buf))),
191 Err(_) => { /* skip invalid json line */ },
192 }
193 }
194 }
195 // Need more bytes
196 match s.next().await {
197 Some(Ok(bytes)) => buf.extend_from_slice(&bytes),
198 Some(Err(e)) => {
199 return Some((
200 Err(crate::client::error::ZaiError::Unknown {
201 code: 0,
202 message: format!("Stream error: {}", e),
203 }),
204 (s, buf),
205 ));
206 },
207 None => return None,
208 }
209 }
210 })
211 .boxed();
212
213 Ok(out)
214 }
215 }
216}