Skip to main content

xai_rust/
stream.rs

//! Server-Sent Events (SSE) streaming support.
//!
//! # Streaming Function Calls
//!
//! When streaming responses that include function/tool calls, the arguments
//! arrive as deltas spread across multiple chunks. Accumulate the argument
//! strings yourself, keyed by the tool call's `id` (which is an empty string
//! when the server omits it):
//!
//! ```rust,no_run
//! # use xai_rust::XaiClient;
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let client = XaiClient::from_env()?;
//! use futures_util::StreamExt;
//! use std::collections::HashMap;
//!
//! let mut stream = client.responses()
//!     .create("grok-4")
//!     .user("What's the weather in NYC?")
//!     .tool(xai_rust::chat::tool(
//!         "get_weather", "Get weather",
//!         serde_json::json!({"type": "object", "properties": {"city": {"type": "string"}}}),
//!     ))
//!     .stream()
//!     .await?;
//!
//! let mut tool_args: HashMap<String, String> = HashMap::new();
//!
//! while let Some(chunk) = stream.next().await {
//!     let chunk = chunk?;
//!     for tc in &chunk.tool_calls {
//!         if let Some(ref func) = tc.function {
//!             tool_args.entry(tc.id.clone())
//!                 .or_default()
//!                 .push_str(&func.arguments);
//!         }
//!     }
//! }
//! # Ok(())
//! # }
//! ```
40
41use bytes::Bytes;
42use futures_util::Stream;
43use pin_project_lite::pin_project;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use crate::error::{Error, Result};
48use crate::models::response::{Response, StreamChunk};
49use crate::models::tool::ToolCall;
50
51pin_project! {
52    /// A stream of response chunks from the API.
53    pub struct ResponseStream {
54        #[pin]
55        inner: Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>,
56        buffer: String,
57        raw_buffer: Vec<u8>,
58        done: bool,
59    }
60}
61
62impl ResponseStream {
63    /// Create a new response stream from a byte stream.
64    pub fn new<S>(stream: S) -> Self
65    where
66        S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + 'static,
67    {
68        use futures_util::StreamExt;
69
70        let mapped = stream.map(|result| result.map_err(Error::from));
71        Self {
72            inner: Box::pin(mapped),
73            buffer: String::new(),
74            raw_buffer: Vec::new(),
75            done: false,
76        }
77    }
78
79    /// Parse an SSE event line into a stream chunk.
80    fn parse_event(data: &str) -> Result<Option<StreamChunk>> {
81        let data = data.trim();
82
83        // Skip empty lines and comments
84        if data.is_empty() || data.starts_with(':') {
85            return Ok(None);
86        }
87
88        // Check for [DONE] marker
89        if data == "[DONE]" {
90            return Ok(Some(StreamChunk {
91                delta: None,
92                reasoning_delta: None,
93                tool_calls: vec![],
94                done: true,
95                response: None,
96            }));
97        }
98
99        // Parse the JSON data
100        let event: StreamEvent = serde_json::from_str(data)?;
101
102        match event {
103            StreamEvent::ResponseDelta(delta) => {
104                let text_delta = delta.delta.and_then(|d| {
105                    d.content.and_then(|parts| {
106                        let text = parts
107                            .into_iter()
108                            .filter_map(|part| {
109                                if let DeltaContentPart::Text { text } = part {
110                                    Some(text)
111                                } else {
112                                    None
113                                }
114                            })
115                            .collect::<String>();
116                        if text.is_empty() {
117                            None
118                        } else {
119                            Some(text)
120                        }
121                    })
122                });
123
124                Ok(Some(StreamChunk {
125                    delta: text_delta,
126                    reasoning_delta: None,
127                    tool_calls: vec![],
128                    done: false,
129                    response: None,
130                }))
131            }
132            StreamEvent::ResponseDone(done) => Ok(Some(StreamChunk {
133                delta: None,
134                reasoning_delta: None,
135                tool_calls: vec![],
136                done: true,
137                response: Some(done.response),
138            })),
139            StreamEvent::ResponseToolCallDelta(delta) => {
140                let tool_call = delta.delta.map(|d| ToolCall {
141                    id: delta.tool_call_id.unwrap_or_default(),
142                    call_type: Some("function".to_string()),
143                    function: d.function,
144                });
145
146                Ok(Some(StreamChunk {
147                    delta: None,
148                    reasoning_delta: None,
149                    tool_calls: tool_call.into_iter().collect(),
150                    done: false,
151                    response: None,
152                }))
153            }
154            _ => Ok(None),
155        }
156    }
157
158    /// Parse a full SSE event block and return a stream chunk if present.
159    fn parse_sse_event(event_block: &str) -> Result<Option<StreamChunk>> {
160        let mut first_data_line: Option<&str> = None;
161        let mut merged_payload: Option<String> = None;
162
163        for line in event_block.lines() {
164            if line.is_empty() || line.starts_with(':') {
165                continue;
166            }
167
168            if line == "data" {
169                if let Some(payload) = merged_payload.as_mut() {
170                    payload.push('\n');
171                } else if let Some(first) = first_data_line {
172                    let mut payload = String::with_capacity(first.len() + 1);
173                    payload.push_str(first);
174                    payload.push('\n');
175                    merged_payload = Some(payload);
176                } else {
177                    first_data_line = Some("");
178                }
179            } else if let Some(data) = line.strip_prefix("data:") {
180                // SSE allows optional single space after "data:"
181                let payload_line = data.strip_prefix(' ').unwrap_or(data);
182                if let Some(payload) = merged_payload.as_mut() {
183                    payload.push('\n');
184                    payload.push_str(payload_line);
185                } else if let Some(first) = first_data_line {
186                    let mut payload = String::with_capacity(first.len() + 1 + payload_line.len());
187                    payload.push_str(first);
188                    payload.push('\n');
189                    payload.push_str(payload_line);
190                    merged_payload = Some(payload);
191                } else {
192                    first_data_line = Some(payload_line);
193                }
194            }
195        }
196
197        if let Some(payload) = merged_payload {
198            return Self::parse_event(&payload);
199        }
200
201        if let Some(payload) = first_data_line {
202            return Self::parse_event(payload);
203        }
204
205        Ok(None)
206    }
207
208    fn find_event_separator(buffer: &str) -> Option<(usize, usize)> {
209        let bytes = buffer.as_bytes();
210        let mut i = 0usize;
211        while i + 1 < bytes.len() {
212            if bytes[i] == b'\n' && bytes[i + 1] == b'\n' {
213                return Some((i, 2));
214            }
215
216            if bytes[i] == b'\r' && bytes[i + 1] == b'\r' {
217                return Some((i, 2));
218            }
219
220            if i + 3 < bytes.len()
221                && bytes[i] == b'\r'
222                && bytes[i + 1] == b'\n'
223                && bytes[i + 2] == b'\r'
224                && bytes[i + 3] == b'\n'
225            {
226                return Some((i, 4));
227            }
228
229            if i + 2 < bytes.len() {
230                // Mixed line-ending separators can appear in intermediary proxies.
231                if bytes[i] == b'\r' && bytes[i + 1] == b'\n' && bytes[i + 2] == b'\n' {
232                    return Some((i, 3));
233                }
234                if bytes[i] == b'\n' && bytes[i + 1] == b'\r' && bytes[i + 2] == b'\n' {
235                    return Some((i, 3));
236                }
237            }
238
239            i += 1;
240        }
241        None
242    }
243}
244
245impl Stream for ResponseStream {
246    type Item = Result<StreamChunk>;
247
248    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
249        let mut this = self.project();
250
251        if *this.done {
252            return Poll::Ready(None);
253        }
254
255        loop {
256            // Check if we have a complete event in the buffer
257            if let Some((pos, sep_len)) = Self::find_event_separator(this.buffer) {
258                let parsed = {
259                    let event_str = &this.buffer[..pos];
260                    Self::parse_sse_event(event_str)
261                };
262                this.buffer.drain(..pos + sep_len);
263
264                match parsed {
265                    Ok(Some(chunk)) => {
266                        if chunk.done {
267                            *this.done = true;
268                        }
269                        return Poll::Ready(Some(Ok(chunk)));
270                    }
271                    Ok(None) => continue,
272                    Err(e) => return Poll::Ready(Some(Err(e))),
273                }
274            }
275
276            // Need more data
277            match this.inner.as_mut().poll_next(cx) {
278                Poll::Ready(Some(Ok(bytes))) => {
279                    this.raw_buffer.extend_from_slice(&bytes);
280                    // Find the longest valid UTF-8 prefix
281                    match std::str::from_utf8(this.raw_buffer) {
282                        Ok(text) => {
283                            this.buffer.push_str(text);
284                            this.raw_buffer.clear();
285                        }
286                        Err(e) => {
287                            let valid_up_to = e.valid_up_to();
288                            if valid_up_to > 0 {
289                                // Safety: we just validated this prefix is valid UTF-8
290                                let valid = std::str::from_utf8(&this.raw_buffer[..valid_up_to])
291                                    .expect("valid_up_to guarantees valid UTF-8");
292                                this.buffer.push_str(valid);
293                                this.raw_buffer.drain(..valid_up_to);
294                            }
295                            // Remaining bytes may be a partial multi-byte char; wait for more data
296                        }
297                    }
298                }
299                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
300                Poll::Ready(None) => {
301                    *this.done = true;
302                    return Poll::Ready(None);
303                }
304                Poll::Pending => return Poll::Pending,
305            }
306        }
307    }
308}
309
310/// SSE event types from the API.
311#[derive(Debug, serde::Deserialize)]
312#[serde(tag = "type", rename_all = "snake_case")]
313enum StreamEvent {
314    /// Response content delta.
315    #[serde(rename = "response.output_item.delta")]
316    ResponseDelta(ResponseDeltaEvent),
317    /// Response complete.
318    #[serde(rename = "response.done")]
319    ResponseDone(ResponseDoneEvent),
320    /// Tool call delta.
321    #[serde(rename = "response.function_call_arguments.delta")]
322    ResponseToolCallDelta(ToolCallDeltaEvent),
323    /// Response created.
324    #[serde(rename = "response.created")]
325    ResponseCreated {},
326    /// Output item added.
327    #[serde(rename = "response.output_item.added")]
328    OutputItemAdded {},
329    /// Output item done.
330    #[serde(rename = "response.output_item.done")]
331    OutputItemDone {},
332    /// Content part added.
333    #[serde(rename = "response.content_part.added")]
334    ContentPartAdded {},
335    /// Content part done.
336    #[serde(rename = "response.content_part.done")]
337    ContentPartDone {},
338    /// Unknown event type.
339    #[serde(other)]
340    Unknown,
341}
342
343#[derive(Debug, serde::Deserialize)]
344struct ResponseDeltaEvent {
345    delta: Option<DeltaContent>,
346}
347
348#[derive(Debug, serde::Deserialize)]
349struct DeltaContent {
350    content: Option<Vec<DeltaContentPart>>,
351}
352
353#[derive(Debug, serde::Deserialize)]
354#[serde(tag = "type", rename_all = "snake_case")]
355enum DeltaContentPart {
356    Text {
357        text: String,
358    },
359    #[serde(other)]
360    Other,
361}
362
363#[derive(Debug, serde::Deserialize)]
364struct ResponseDoneEvent {
365    response: Response,
366}
367
368#[derive(Debug, serde::Deserialize)]
369struct ToolCallDeltaEvent {
370    tool_call_id: Option<String>,
371    delta: Option<ToolCallDelta>,
372}
373
374#[derive(Debug, serde::Deserialize)]
375struct ToolCallDelta {
376    function: Option<crate::models::tool::FunctionCall>,
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::{stream, StreamExt};

    /// Text delta event whose only content part is the text `Hello`.
    const HELLO_DELTA: &str = r#"{"type":"response.output_item.delta","delta":{"content":[{"type":"text","text":"Hello"}]}}"#;

    /// Feed `parts` through a [`ResponseStream`] as successive network chunks
    /// and collect every yielded chunk, panicking if the stream yields an error.
    async fn collect_chunks(parts: Vec<Bytes>) -> Vec<StreamChunk> {
        let raw: Vec<std::result::Result<Bytes, reqwest::Error>> =
            parts.into_iter().map(Ok).collect();
        ResponseStream::new(stream::iter(raw))
            .map(|item| item.expect("stream yielded an error"))
            .collect::<Vec<_>>()
            .await
    }

    /// Assert that `payload`, delivered as one network chunk, yields exactly a
    /// `Hello` text delta followed by a terminal chunk, and then ends.
    async fn assert_hello_then_done(payload: String) {
        let chunks = collect_chunks(vec![Bytes::from(payload)]).await;
        assert_eq!(chunks.len(), 2, "expected a text delta and a done chunk");
        assert_eq!(chunks[0].delta.as_deref(), Some("Hello"));
        assert!(!chunks[0].done);
        assert!(chunks[1].done);
    }

    // ── parse_event with valid JSON ────────────────────────────────────

    #[test]
    fn parse_event_response_delta_text() {
        let chunk = ResponseStream::parse_event(HELLO_DELTA).unwrap().unwrap();
        assert!(!chunk.done);
        assert_eq!(chunk.delta.as_deref(), Some("Hello"));
        assert!(chunk.response.is_none());
    }

    #[test]
    fn parse_event_response_delta_multiple_text_parts_concatenates() {
        let data = r#"{"type":"response.output_item.delta","delta":{"content":[{"type":"text","text":"Hel"},{"type":"text","text":"lo"}]}}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        assert!(!chunk.done);
        assert_eq!(chunk.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_event_response_delta_no_content() {
        // A delta event with no content field.
        let data = r#"{"type":"response.output_item.delta","delta":{}}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        assert!(!chunk.done);
        assert!(chunk.delta.is_none());
    }

    #[test]
    fn parse_event_response_done() {
        let data = r#"{"type":"response.done","response":{"id":"resp_1","model":"grok-4","output":[{"type":"message","role":"assistant","content":[{"type":"text","text":"Done"}]}],"usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15}}}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        assert!(chunk.done);
        let resp = chunk.response.expect("response.done carries the response");
        assert_eq!(resp.id, "resp_1");
        assert_eq!(resp.output_text().unwrap(), "Done");
    }

    #[test]
    fn parse_event_tool_call_delta() {
        let data = r#"{"type":"response.function_call_arguments.delta","tool_call_id":"call_1","delta":{"function":{"name":"get_weather","arguments":"{\"city\":"}}}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        assert!(!chunk.done);
        assert_eq!(chunk.tool_calls.len(), 1);
        assert_eq!(chunk.tool_calls[0].id, "call_1");
        assert_eq!(chunk.tool_calls[0].call_type.as_deref(), Some("function"));
        assert_eq!(
            chunk.tool_calls[0].function.as_ref().unwrap().name,
            "get_weather"
        );
    }

    // ── parse_event with [DONE] marker ─────────────────────────────────

    #[test]
    fn parse_event_done_marker() {
        let chunk = ResponseStream::parse_event("[DONE]").unwrap().unwrap();
        assert!(chunk.done);
        assert!(chunk.delta.is_none());
        assert!(chunk.response.is_none());
        assert!(chunk.tool_calls.is_empty());
    }

    #[test]
    fn parse_event_done_marker_trims_whitespace() {
        let chunk = ResponseStream::parse_event("  [DONE]\n").unwrap().unwrap();
        assert!(chunk.done);
    }

    // ── parse_event with empty/comment lines ──────────────────────────

    #[test]
    fn parse_event_empty_string() {
        assert!(ResponseStream::parse_event("").unwrap().is_none());
    }

    #[test]
    fn parse_event_comment_line() {
        assert!(ResponseStream::parse_event(": this is a comment")
            .unwrap()
            .is_none());
    }

    #[test]
    fn parse_event_comment_colon_only() {
        assert!(ResponseStream::parse_event(":").unwrap().is_none());
    }

    // ── parse_event with unknown event types ──────────────────────────

    #[test]
    fn parse_event_unknown_type_returns_none() {
        let data = r#"{"type":"response.created"}"#;
        assert!(ResponseStream::parse_event(data).unwrap().is_none());
    }

    #[test]
    fn parse_event_content_part_added_returns_none() {
        let data = r#"{"type":"response.content_part.added"}"#;
        assert!(ResponseStream::parse_event(data).unwrap().is_none());
    }

    // ── parse_event with invalid JSON ─────────────────────────────────

    #[test]
    fn parse_event_invalid_json() {
        assert!(ResponseStream::parse_event("{not valid json}").is_err());
    }

    #[test]
    fn parse_event_completely_broken() {
        assert!(ResponseStream::parse_event("just random text").is_err());
    }

    // ── parse_event tool call without tool_call_id ────────────────────

    #[test]
    fn parse_event_tool_call_delta_no_id() {
        let data = r#"{"type":"response.function_call_arguments.delta","delta":{"function":{"name":"fn1","arguments":"{}"}}}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        assert_eq!(chunk.tool_calls.len(), 1);
        // tool_call_id defaults to an empty string when absent.
        assert_eq!(chunk.tool_calls[0].id, "");
    }

    #[test]
    fn parse_event_tool_call_delta_no_delta() {
        let data = r#"{"type":"response.function_call_arguments.delta","tool_call_id":"call_2"}"#;
        let chunk = ResponseStream::parse_event(data).unwrap().unwrap();
        // No delta means no tool call in the vec.
        assert!(chunk.tool_calls.is_empty());
    }

    // ── parse_sse_event ───────────────────────────────────────────────

    #[test]
    fn parse_sse_event_data_without_space() {
        let event = format!("data:{}", HELLO_DELTA);
        let chunk = ResponseStream::parse_sse_event(&event).unwrap().unwrap();
        assert_eq!(chunk.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_sse_event_multiline_data_concatenates_with_newline() {
        let event = "data: {\"type\":\"response.output_item.delta\",\n\
                     data: \"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hello\"}]}}";
        let chunk = ResponseStream::parse_sse_event(event).unwrap().unwrap();
        assert_eq!(chunk.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_sse_event_accepts_bare_data_line_before_done() {
        let chunk = ResponseStream::parse_sse_event("data\ndata: [DONE]")
            .unwrap()
            .unwrap();
        assert!(chunk.done);
    }

    #[test]
    fn parse_sse_event_ignores_comments_and_other_fields() {
        let event = ": keep-alive\nevent: message\nid: 7\ndata: [DONE]";
        let chunk = ResponseStream::parse_sse_event(event).unwrap().unwrap();
        assert!(chunk.done);
    }

    #[test]
    fn parse_sse_event_without_data_returns_none() {
        assert!(ResponseStream::parse_sse_event("event: ping")
            .unwrap()
            .is_none());
    }

    // ── find_event_separator ──────────────────────────────────────────

    #[test]
    fn find_event_separator_waits_for_complete_crlf_blank_line() {
        assert_eq!(ResponseStream::find_event_separator("data: x\r\n\r"), None);
        assert_eq!(
            ResponseStream::find_event_separator("data: x\r\n\r\n"),
            Some((7, 4))
        );
    }

    #[test]
    fn find_event_separator_reports_earliest_separator() {
        assert_eq!(ResponseStream::find_event_separator("a\n\nb\r\r"), Some((1, 2)));
        assert_eq!(ResponseStream::find_event_separator("a\n\r\nb\n\n"), Some((1, 3)));
        assert_eq!(ResponseStream::find_event_separator("no separator\n"), None);
    }

    // ── full stream ───────────────────────────────────────────────────

    #[tokio::test]
    async fn stream_handles_crlf_event_separators() {
        assert_hello_then_done(format!(
            "data: {}\r\n\r\ndata: [DONE]\r\n\r\n",
            HELLO_DELTA
        ))
        .await;
    }

    #[tokio::test]
    async fn stream_handles_cr_only_event_separators() {
        assert_hello_then_done(format!("data: {}\r\rdata: [DONE]\r\r", HELLO_DELTA)).await;
    }

    #[tokio::test]
    async fn stream_handles_mixed_event_separators() {
        assert_hello_then_done(format!(
            "data: {}\r\n\ndata: [DONE]\n\r\n",
            HELLO_DELTA
        ))
        .await;
    }

    #[tokio::test]
    async fn stream_reassembles_event_split_across_chunks() {
        let payload = format!("data: {}\n\ndata: [DONE]\n\n", HELLO_DELTA);
        // The payload is ASCII, so any byte offset is a char boundary.
        let (head, tail) = payload.split_at(20);
        let chunks = collect_chunks(vec![
            Bytes::from(head.to_owned()),
            Bytes::from(tail.to_owned()),
        ])
        .await;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].delta.as_deref(), Some("Hello"));
        assert!(chunks[1].done);
    }

    #[tokio::test]
    async fn stream_reassembles_multibyte_char_split_across_chunks() {
        let payload = concat!(
            "data: {\"type\":\"response.output_item.delta\",\"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"caf\u{e9}\"}]}}\n\n",
            "data: [DONE]\n\n"
        );
        // Split between the two bytes of the UTF-8 encoding of 'é'.
        let split = payload.find('\u{e9}').expect("payload contains 'é'") + 1;
        let bytes = payload.as_bytes();
        let chunks = collect_chunks(vec![
            Bytes::copy_from_slice(&bytes[..split]),
            Bytes::copy_from_slice(&bytes[split..]),
        ])
        .await;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].delta.as_deref(), Some("caf\u{e9}"));
        assert!(chunks[1].done);
    }
}