zai_rs/model/stream_ext.rs
1//! # Streaming Extensions for Chat-like Endpoints
2//!
3//! This module provides typed streaming capabilities for chat completion APIs
4//! that return Server-Sent Events (SSE) with `ChatStreamResponse` chunks.
5//!
6//! ## Features
7//!
8//! - **Callback-based API** - Simple async closure interface for processing chunks
9//! - **Stream-based API** - Composable, testable, and reusable stream interface
10//! - **Type-safe parsing** - Automatic deserialization of SSE data chunks
//! - **Error handling** - Transport errors are propagated to the caller; `data` lines whose JSON cannot be decoded are skipped
12//!
13//! ## Usage Patterns
14//!
15//! ### Callback-based Processing
16//! ```rust,ignore
17//! client.stream_for_each(|chunk| async move {
18//! println!("Received: {:?}", chunk);
19//! Ok(())
20//! }).await?;
21//! ```
22//!
23//! ### Stream-based Processing
24//! ```rust,ignore
25//! let mut stream = client.to_stream().await?;
26//! while let Some(result) = stream.next().await {
27//! match result {
28//! Ok(chunk) => println!("Chunk: {:?}", chunk),
29//! Err(e) => eprintln!("Error: {}", e),
30//! }
31//! }
32//! ```
33
34use crate::client::http::HttpClient;
35use crate::model::chat_stream_response::ChatStreamResponse;
36use crate::model::traits::SseStreamable;
37use futures::{Stream, StreamExt, stream};
38use log::info;
39use std::pin::Pin;
40
41/// Streaming extension trait for chat-like endpoints.
42///
43/// This trait provides two complementary APIs for processing streaming responses:
44/// 1. **Callback-based** - Simple async closure interface
45/// 2. **Stream-based** - Composable stream interface for advanced usage
46///
47/// Both APIs handle SSE protocol parsing, JSON deserialization, and error propagation.
48pub trait StreamChatLikeExt: SseStreamable + HttpClient {
49 /// Processes streaming responses using an async callback function.
50 ///
51 /// This method provides a simple interface for handling streaming chat responses.
52 /// Each successfully parsed chunk is passed to the provided callback function.
53 ///
54 /// ## Arguments
55 ///
56 /// * `on_chunk` - Async callback function that processes each `ChatStreamResponse` chunk
57 ///
58 /// ## Returns
59 ///
60 /// Result indicating success or failure of the streaming operation
61 ///
62 /// ## Example
63 ///
64 /// ```rust,ignore
65 /// client.stream_for_each(|chunk| async move {
66 /// if let Some(content) = &chunk.choices[0].delta.content {
67 /// print!("{}", content);
68 /// }
69 /// Ok(())
70 /// }).await?;
71 /// ```
72 fn stream_for_each<'a, F, Fut>(
73 &'a mut self,
74 mut on_chunk: F,
75 ) -> impl core::future::Future<Output = anyhow::Result<()>> + 'a
76 where
77 F: FnMut(ChatStreamResponse) -> Fut + 'a,
78 Fut: core::future::Future<Output = anyhow::Result<()>> + 'a,
79 {
80 async move {
81 let resp = self.post().await?;
82 let mut stream = resp.bytes_stream();
83 let mut buf: Vec<u8> = Vec::new();
84
85 while let Some(next) = stream.next().await {
86 let bytes = match next {
87 Ok(b) => b,
88 Err(e) => return Err(anyhow::anyhow!("Stream error: {}", e)),
89 };
90 buf.extend_from_slice(&bytes);
91 while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
92 let line_vec: Vec<u8> = buf.drain(..=pos).collect();
93 let mut line = &line_vec[..];
94 if line.ends_with(b"\n") {
95 line = &line[..line.len() - 1];
96 }
97 if line.ends_with(b"\r") {
98 line = &line[..line.len() - 1];
99 }
100 if line.is_empty() {
101 continue;
102 }
103 const PREFIX: &[u8] = b"data: ";
104 if line.starts_with(PREFIX) {
105 let rest = &line[PREFIX.len()..];
106 info!("SSE data: {}", String::from_utf8_lossy(rest));
107 if rest == b"[DONE]" {
108 return Ok(());
109 }
110 if let Ok(chunk) = serde_json::from_slice::<ChatStreamResponse>(rest) {
111 on_chunk(chunk).await?;
112 }
113 }
114 }
115 }
116 Ok(())
117 }
118 }
119
120 /// Converts the streaming response into a composable Stream.
121 ///
122 /// This method returns a `Stream` that yields `ChatStreamResponse` chunks,
123 /// enabling advanced stream processing operations like filtering, mapping,
124 /// and combination with other streams.
125 ///
126 /// ## Returns
127 ///
128 /// A future that resolves to a `Stream` of `Result<ChatStreamResponse>` items
129 ///
130 /// ## Example
131 ///
132 /// ```rust,ignore
133 /// let stream = client.to_stream().await?;
134 /// let collected: Vec<_> = stream
135 /// .filter_map(|result| result.ok())
136 /// .collect()
137 /// .await;
138 /// ```
139 fn to_stream<'a>(
140 &'a mut self,
141 ) -> impl core::future::Future<
142 Output = anyhow::Result<
143 Pin<Box<dyn Stream<Item = anyhow::Result<ChatStreamResponse>> + Send + 'static>>,
144 >,
145 > + 'a {
146 async move {
147 let resp = self.post().await?;
148 let byte_stream = resp.bytes_stream();
149
150 // State: (byte_stream, buffer)
151 let s = byte_stream;
152
153 let out = stream::unfold((s, Vec::<u8>::new()), |(mut s, mut buf)| async move {
154 loop {
155 // Process all complete lines currently in buffer
156 while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
157 let line_vec: Vec<u8> = buf.drain(..=pos).collect();
158 let mut line = &line_vec[..];
159 if line.ends_with(b"\n") {
160 line = &line[..line.len() - 1];
161 }
162 if line.ends_with(b"\r") {
163 line = &line[..line.len() - 1];
164 }
165 if line.is_empty() {
166 continue;
167 }
168 const PREFIX: &[u8] = b"data: ";
169 if line.starts_with(PREFIX) {
170 let rest = &line[PREFIX.len()..];
171 info!("SSE data: {}", String::from_utf8_lossy(rest));
172 if rest == b"[DONE]" {
173 return None; // end stream gracefully
174 }
175 match serde_json::from_slice::<ChatStreamResponse>(rest) {
176 Ok(item) => return Some((Ok(item), (s, buf))),
177 Err(_) => { /* skip invalid json line */ }
178 }
179 }
180 }
181 // Need more bytes
182 match s.next().await {
183 Some(Ok(bytes)) => buf.extend_from_slice(&bytes),
184 Some(Err(e)) => {
185 return Some((Err(anyhow::anyhow!("Stream error: {}", e)), (s, buf)));
186 }
187 None => return None,
188 }
189 }
190 })
191 .boxed();
192
193 Ok(out)
194 }
195 }
196}