struct_llm/
streaming.rs

//! Streaming parser for incremental tool call parsing from SSE responses.
//!
//! This module enables real-time processing of tool calls as they stream
//! from LLM APIs, without waiting for the complete response.
5use crate::{error::Result, Provider, ToolCall};
6use serde::Deserialize;
7
/// Delta updates emitted while a streamed tool call is being constructed.
///
/// The parser emits `Start` when a tool call begins, `Arguments` for each
/// argument fragment it receives, and `End` once the call is complete.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolDelta {
    /// A tool call has started.
    Start {
        /// Identifier of the tool call as reported by the provider.
        id: String,
        /// Name of the tool being invoked.
        name: String,
    },
    /// New argument text for the current tool call.
    ///
    /// A fragment is usually not valid JSON on its own; concatenate the
    /// fragments in arrival order to obtain the full argument document.
    Arguments {
        /// The raw argument text received in this chunk.
        delta: String,
    },
    /// The tool call is complete.
    End,
}
18
19/// Parser for streaming SSE responses
20///
21/// Handles incremental parsing of tool calls from Server-Sent Events (SSE)
22/// streams. Accumulates partial data and emits deltas as they arrive.
23///
24/// # Example
25///
26/// ```ignore
27/// let mut parser = StreamParser::new(Provider::Anthropic);
28///
29/// // Process each SSE chunk
30/// for chunk in sse_stream {
31///     if let Some(delta) = parser.parse_chunk(&chunk)? {
32///         match delta {
33///             ToolDelta::Start { name, .. } => println!("Starting: {}", name),
34///             ToolDelta::Arguments { delta } => print!("{}", delta),
35///             ToolDelta::End => println!("\nDone!"),
36///         }
37///     }
38/// }
39///
40/// // Get the complete tool call
41/// let tool_call = parser.finalize()?;
42/// ```
43pub struct StreamParser {
44    provider: Provider,
45    state: ParserState,
46}
47
/// Mutable bookkeeping for the tool call currently being streamed.
#[derive(Default, Debug)]
struct ParserState {
    /// Identifier taken from the most recent tool-call start event, if any.
    current_tool_id: Option<String>,
    /// Tool name taken from the most recent tool-call start event, if any.
    current_tool_name: Option<String>,
    /// Argument fragments appended in arrival order.
    accumulated_args: String,
    /// Set once the provider has signalled the end of the tool call.
    is_complete: bool,
}
55
56impl StreamParser {
57    /// Create a new streaming parser for the specified provider
58    pub fn new(provider: Provider) -> Self {
59        Self {
60            provider,
61            state: ParserState::default(),
62        }
63    }
64
65    /// Parse a single SSE chunk and return any deltas
66    ///
67    /// Returns `None` if the chunk doesn't contain relevant tool call data.
68    /// Returns `Some(ToolDelta)` when tool call events occur.
69    pub fn parse_chunk(&mut self, chunk: &str) -> Result<Option<ToolDelta>> {
70        match self.provider {
71            Provider::OpenAI => self.parse_openai_chunk(chunk),
72            Provider::Anthropic => self.parse_anthropic_chunk(chunk),
73            Provider::Local => self.parse_local_chunk(chunk),
74        }
75    }
76
77    /// Get the final complete tool call after streaming is done
78    ///
79    /// This should be called after all chunks have been processed.
80    pub fn finalize(self) -> Result<ToolCall> {
81        if !self.state.is_complete {
82            return Err(crate::error::Error::InvalidResponseFormat(
83                "Tool call stream not completed".to_string(),
84            ));
85        }
86
87        let id = self
88            .state
89            .current_tool_id
90            .ok_or(crate::error::Error::NoToolCalls)?;
91        let name = self
92            .state
93            .current_tool_name
94            .ok_or(crate::error::Error::NoToolCalls)?;
95
96        // Parse accumulated arguments as JSON
97        let arguments: serde_json::Value = if self.state.accumulated_args.is_empty() {
98            serde_json::json!({})
99        } else {
100            serde_json::from_str(&self.state.accumulated_args)?
101        };
102
103        Ok(ToolCall {
104            id,
105            name,
106            arguments,
107        })
108    }
109
110    fn parse_openai_chunk(&mut self, chunk: &str) -> Result<Option<ToolDelta>> {
111        // OpenAI SSE format: "data: {json}\n\n"
112        let chunk = chunk.trim();
113
114        if chunk.starts_with("data: [DONE]") {
115            if self.state.current_tool_id.is_some() {
116                self.state.is_complete = true;
117                return Ok(Some(ToolDelta::End));
118            }
119            return Ok(None);
120        }
121
122        if !chunk.starts_with("data: ") {
123            return Ok(None);
124        }
125
126        let json_str = chunk.strip_prefix("data: ").unwrap_or(chunk);
127        let chunk_data: serde_json::Value = serde_json::from_str(json_str)?;
128
129        // Look for tool_calls in delta
130        if let Some(choices) = chunk_data["choices"].as_array() {
131            if let Some(choice) = choices.first() {
132                if let Some(delta) = choice["delta"].as_object() {
133                    if let Some(tool_calls) = delta.get("tool_calls") {
134                        if let Some(tool_call) = tool_calls.as_array().and_then(|arr| arr.first()) {
135                            // Check for tool call start
136                            if let Some(id) = tool_call["id"].as_str() {
137                                if let Some(name) = tool_call["function"]["name"].as_str() {
138                                    self.state.current_tool_id = Some(id.to_string());
139                                    self.state.current_tool_name = Some(name.to_string());
140                                    return Ok(Some(ToolDelta::Start {
141                                        id: id.to_string(),
142                                        name: name.to_string(),
143                                    }));
144                                }
145                            }
146
147                            // Check for arguments delta
148                            if let Some(args_delta) = tool_call["function"]["arguments"].as_str() {
149                                if !args_delta.is_empty() {
150                                    self.state.accumulated_args.push_str(args_delta);
151                                    return Ok(Some(ToolDelta::Arguments {
152                                        delta: args_delta.to_string(),
153                                    }));
154                                }
155                            }
156                        }
157                    }
158                }
159            }
160        }
161
162        Ok(None)
163    }
164
165    fn parse_anthropic_chunk(&mut self, chunk: &str) -> Result<Option<ToolDelta>> {
166        // Anthropic SSE format
167        let chunk = chunk.trim();
168
169        if !chunk.starts_with("data: ") {
170            return Ok(None);
171        }
172
173        let json_str = chunk.strip_prefix("data: ").unwrap_or(chunk);
174        let event: AnthropicEvent = serde_json::from_str(json_str)?;
175
176        match event.event_type.as_str() {
177            "content_block_start" => {
178                if let Some(content) = event.content_block {
179                    if content.block_type == "tool_use" {
180                        let id = content.id.unwrap_or_default();
181                        let name = content.name.unwrap_or_default();
182
183                        self.state.current_tool_id = Some(id.clone());
184                        self.state.current_tool_name = Some(name.clone());
185
186                        return Ok(Some(ToolDelta::Start { id, name }));
187                    }
188                }
189            }
190            "content_block_delta" => {
191                if let Some(delta) = event.delta {
192                    if delta.delta_type == "input_json_delta" {
193                        if let Some(partial_json) = delta.partial_json {
194                            self.state.accumulated_args.push_str(&partial_json);
195                            return Ok(Some(ToolDelta::Arguments {
196                                delta: partial_json,
197                            }));
198                        }
199                    }
200                }
201            }
202            "content_block_stop" => {
203                if self.state.current_tool_id.is_some() {
204                    self.state.is_complete = true;
205                    return Ok(Some(ToolDelta::End));
206                }
207            }
208            "message_stop" => {
209                if self.state.current_tool_id.is_some() && !self.state.is_complete {
210                    self.state.is_complete = true;
211                    return Ok(Some(ToolDelta::End));
212                }
213            }
214            _ => {}
215        }
216
217        Ok(None)
218    }
219
220    fn parse_local_chunk(&mut self, chunk: &str) -> Result<Option<ToolDelta>> {
221        // Local/generic format (similar to OpenAI)
222        self.parse_openai_chunk(chunk)
223    }
224}
225
226#[derive(Debug, Deserialize)]
227struct AnthropicEvent {
228    #[serde(rename = "type")]
229    event_type: String,
230    #[serde(default)]
231    content_block: Option<ContentBlock>,
232    #[serde(default)]
233    delta: Option<Delta>,
234}
235
236#[derive(Debug, Deserialize)]
237struct ContentBlock {
238    #[serde(rename = "type")]
239    block_type: String,
240    #[serde(default)]
241    id: Option<String>,
242    #[serde(default)]
243    name: Option<String>,
244}
245
246#[derive(Debug, Deserialize)]
247struct Delta {
248    #[serde(rename = "type")]
249    delta_type: String,
250    #[serde(default)]
251    partial_json: Option<String>,
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `chunks` to `parser` in order and collect every emitted delta.
    fn collect_deltas(parser: &mut StreamParser, chunks: &[&str]) -> Vec<ToolDelta> {
        chunks
            .iter()
            .filter_map(|chunk| parser.parse_chunk(chunk).unwrap())
            .collect()
    }

    #[test]
    fn test_openai_streaming() {
        let mut parser = StreamParser::new(Provider::OpenAI);

        // Simulate OpenAI SSE chunks
        let chunks = [
            r#"data: {"choices":[{"delta":{"tool_calls":[{"id":"call_123","function":{"name":"test_tool"}}]}}]}"#,
            r#"data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{"}}]}}]}"#,
            r#"data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\"key\": \"value\""}}]}}]}"#,
            r#"data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"}"}}]}}]}"#,
            "data: [DONE]",
        ];

        let deltas = collect_deltas(&mut parser, &chunks);

        // Start, three argument fragments, End.
        assert_eq!(deltas.len(), 5);
        assert!(matches!(deltas[0], ToolDelta::Start { .. }));
        assert!(matches!(deltas[1], ToolDelta::Arguments { .. }));
        assert!(matches!(deltas.last(), Some(ToolDelta::End)));

        // Finalize and check result
        let tool_call = parser.finalize().unwrap();
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.id, "call_123");
        assert_eq!(tool_call.arguments["key"], "value");
    }

    #[test]
    fn test_anthropic_streaming() {
        let mut parser = StreamParser::new(Provider::Anthropic);

        // Simulate Anthropic SSE chunks
        let chunks = [
            r#"data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"toolu_123","name":"test_tool"}}"#,
            r#"data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"key\": "}}"#,
            r#"data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"\"value\"}"}}"#,
            r#"data: {"type":"content_block_stop"}"#,
        ];

        let deltas = collect_deltas(&mut parser, &chunks);

        // Start, two argument fragments, End.
        assert_eq!(deltas.len(), 4);
        assert!(matches!(deltas[0], ToolDelta::Start { .. }));
        assert!(matches!(deltas[1], ToolDelta::Arguments { .. }));
        assert!(matches!(deltas.last(), Some(ToolDelta::End)));

        // Finalize and check result
        let tool_call = parser.finalize().unwrap();
        assert_eq!(tool_call.name, "test_tool");
        assert_eq!(tool_call.id, "toolu_123");
        assert_eq!(tool_call.arguments["key"], "value");
    }

    #[test]
    fn test_finalize_before_completion_is_error() {
        let mut parser = StreamParser::new(Provider::OpenAI);
        let start = r#"data: {"choices":[{"delta":{"tool_calls":[{"id":"call_123","function":{"name":"test_tool"}}]}}]}"#;

        assert!(matches!(
            parser.parse_chunk(start).unwrap(),
            Some(ToolDelta::Start { .. })
        ));
        // No `[DONE]` has been seen, so the call is not complete yet.
        assert!(parser.finalize().is_err());
    }

    #[test]
    fn test_openai_done_without_tool_call() {
        let mut parser = StreamParser::new(Provider::OpenAI);

        assert!(parser.parse_chunk("data: [DONE]").unwrap().is_none());
        assert!(parser.finalize().is_err());
    }

    #[test]
    fn test_irrelevant_chunks_are_ignored() {
        let mut parser = StreamParser::new(Provider::Anthropic);

        assert!(parser.parse_chunk("").unwrap().is_none());
        assert!(parser.parse_chunk(": keep-alive").unwrap().is_none());
        assert!(parser
            .parse_chunk(r#"data: {"type":"ping"}"#)
            .unwrap()
            .is_none());
        // A stop event with no tool call in progress produces nothing.
        assert!(parser
            .parse_chunk(r#"data: {"type":"message_stop"}"#)
            .unwrap()
            .is_none());
        assert!(parser.finalize().is_err());
    }
}