Skip to main content

tool_parser/parsers/
step3.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use openai_protocol::common::Tool;
5use regex::Regex;
6use serde_json::Value;
7
8use crate::{
9    errors::{ParserError, ParserResult},
10    parsers::helpers,
11    traits::ToolParser,
12    types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
13};
14
15/// Step3 format parser for tool calls
16///
17/// Handles the Step3 specific format with steptml XML:
18/// `<|tool_calls_begin|><|tool_call_begin|>function<|tool_sep|><steptml:invoke name="{name}"><steptml:parameter name="{k}">{v}</steptml:parameter></steptml:invoke><|tool_call_end|><|tool_calls_end|>`
19///
20/// Features:
21/// - Unicode token delimiters
22/// - StepTML XML format for invocations
23/// - Support for multiple sequential tool calls
24pub struct Step3Parser {
25    /// Regex for extracting tool call blocks
26    tool_call_extractor: Regex,
27    /// Regex for extracting steptml invocations
28    invoke_extractor: Regex,
29    /// Regex for extracting parameters
30    param_extractor: Regex,
31
32    /// Buffer for accumulating chunks
33    buffer: String,
34
35    /// Token configuration
36    bot_token: &'static str,
37    eot_token: &'static str,
38    tool_call_begin: &'static str,
39    tool_call_end: &'static str,
40    tool_sep: &'static str,
41
42    /// Streaming state variables (mirrors Python's Step3Detector)
43    in_tool_block: bool,
44    tool_block_finished: bool,
45    current_function_name: String,
46    current_parameters: serde_json::Map<String, Value>,
47    in_tool_call: bool,
48    function_name_sent: bool,
49
50    /// Standard state machine fields
51    prev_tool_call_arr: Vec<Value>,
52    current_tool_id: i32,
53    streamed_args_for_tool: Vec<String>,
54}
55
56impl Step3Parser {
57    /// Create a new Step3 parser
58    #[expect(
59        clippy::expect_used,
60        reason = "regex patterns are compile-time string literals"
61    )]
62    pub fn new() -> Self {
63        // Pattern for individual tool calls
64        let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
65        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
66
67        // Pattern for steptml invocations
68        let invoke_pattern = r#"(?s)<steptml:invoke name="([^"]+)">(.+?)</steptml:invoke>"#;
69        let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
70
71        // Pattern for steptml parameters - using non-greedy match for values to handle < characters
72        let param_pattern = r#"(?s)<steptml:parameter name="([^"]+)">(.+?)</steptml:parameter>"#;
73        let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
74
75        Self {
76            tool_call_extractor,
77            invoke_extractor,
78            param_extractor,
79
80            buffer: String::new(),
81
82            bot_token: "<|tool_calls_begin|>",
83            eot_token: "<|tool_calls_end|>",
84            tool_call_begin: "<|tool_call_begin|>",
85            tool_call_end: "<|tool_call_end|>",
86            tool_sep: "<|tool_sep|>",
87
88            // Streaming state variables
89            in_tool_block: false,
90            tool_block_finished: false,
91            current_function_name: String::new(),
92            current_parameters: serde_json::Map::new(),
93            in_tool_call: false,
94            function_name_sent: false,
95
96            // Standard state machine fields
97            prev_tool_call_arr: Vec::new(),
98            current_tool_id: -1,
99            streamed_args_for_tool: Vec::new(),
100        }
101    }
102
103    /// Reset streaming state for the next tool call
104    fn reset_streaming_state(&mut self) {
105        self.in_tool_call = false;
106        self.function_name_sent = false;
107        self.current_function_name.clear();
108        self.current_parameters.clear();
109    }
110
111    /// Parse partial tool call for streaming scenarios (mirrors Python's _parse_partial_tool_call)
112    fn parse_partial_tool_call(
113        &mut self,
114        tool_indices: &HashMap<String, usize>,
115    ) -> StreamingParseResult {
116        let mut calls = Vec::new();
117
118        // Check if we have tool_sep (means we're past the type declaration)
119        if !self.buffer.contains(self.tool_sep) {
120            return StreamingParseResult {
121                normal_text: String::new(),
122                calls,
123            };
124        }
125
126        // Clone the buffer to avoid borrow conflicts
127        let buffer_clone = self.buffer.clone();
128        let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
129        if parts.len() != 2 {
130            return StreamingParseResult {
131                normal_text: String::new(),
132                calls,
133            };
134        }
135
136        let type_part = parts[0].trim();
137        let invoke_part = parts[1];
138
139        // Check if it's a function type
140        if type_part != "function" {
141            // Invalid tool type, skip this tool call
142            self.reset_streaming_state();
143            return StreamingParseResult {
144                normal_text: String::new(),
145                calls,
146            };
147        }
148
149        // Try to extract function name if not sent yet
150        if !self.function_name_sent {
151            if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
152                let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
153
154                // Validate function name
155                if tool_indices.contains_key(func_name) {
156                    self.current_function_name = func_name.to_string();
157                    self.function_name_sent = true;
158
159                    // Initialize tool tracking
160                    if self.current_tool_id == -1 {
161                        self.current_tool_id = 0;
162                    }
163
164                    // Ensure tracking arrays are large enough
165                    helpers::ensure_capacity(
166                        self.current_tool_id,
167                        &mut self.prev_tool_call_arr,
168                        &mut self.streamed_args_for_tool,
169                    );
170
171                    // Store tool call info
172                    let tool_id = self.current_tool_id as usize;
173                    self.prev_tool_call_arr[tool_id] = serde_json::json!({
174                        "name": func_name,
175                        "arguments": {},
176                    });
177
178                    // Send tool name with empty parameters
179                    calls.push(ToolCallItem {
180                        tool_index: self.current_tool_id as usize,
181                        name: Some(func_name.to_string()),
182                        parameters: String::new(),
183                    });
184                } else {
185                    // Invalid function name
186                    tracing::debug!("Invalid function name: {}", func_name);
187                    self.reset_streaming_state();
188                    return StreamingParseResult {
189                        normal_text: String::new(),
190                        calls,
191                    };
192                }
193            } else {
194                // Function name not complete yet
195                return StreamingParseResult {
196                    normal_text: String::new(),
197                    calls,
198                };
199            }
200        }
201
202        // Parse parameters incrementally
203        if self.function_name_sent {
204            // Extract all complete parameters
205            let mut new_params = serde_json::Map::new();
206            for capture in self.param_extractor.captures_iter(invoke_part) {
207                let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
208                let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
209
210                // Try to parse the value as JSON first, fallback to string
211                let param_value =
212                    if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
213                        json_val
214                    } else {
215                        // Try parsing as Python literal
216                        if param_value_str == "true" || param_value_str == "True" {
217                            Value::Bool(true)
218                        } else if param_value_str == "false" || param_value_str == "False" {
219                            Value::Bool(false)
220                        } else if param_value_str == "null" || param_value_str == "None" {
221                            Value::Null
222                        } else if let Ok(num) = param_value_str.parse::<i64>() {
223                            Value::Number(num.into())
224                        } else if let Ok(num) = param_value_str.parse::<f64>() {
225                            if let Some(n) = serde_json::Number::from_f64(num) {
226                                Value::Number(n)
227                            } else {
228                                Value::String(param_value_str.to_string())
229                            }
230                        } else {
231                            Value::String(param_value_str.to_string())
232                        }
233                    };
234
235                new_params.insert(param_name.to_string(), param_value);
236            }
237
238            // Check if we have new parameters to stream
239            if new_params != self.current_parameters {
240                // Build the JSON content without the closing brace for streaming
241                let diff = if self.current_parameters.is_empty() {
242                    // First parameters - send opening brace and content
243                    let params_content =
244                        serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
245                    if params_content.len() > 2 {
246                        // Send everything except the closing brace
247                        params_content[..params_content.len() - 1].to_string()
248                    } else {
249                        "{".to_string()
250                    }
251                } else {
252                    // Subsequent parameters - calculate the incremental diff
253                    let old_json = serde_json::to_string(&self.current_parameters)
254                        .unwrap_or_else(|_| "{}".to_string());
255                    let new_json =
256                        serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
257
258                    // Remove closing braces for comparison
259                    let old_without_brace = &old_json[..old_json.len() - 1];
260                    let new_without_brace = &new_json[..new_json.len() - 1];
261
262                    // The new content should extend the old content
263                    new_without_brace
264                        .strip_prefix(old_without_brace)
265                        .map(|s| s.to_string())
266                        .unwrap_or_default()
267                };
268
269                if !diff.is_empty() {
270                    calls.push(ToolCallItem {
271                        tool_index: self.current_tool_id as usize,
272                        name: None,
273                        parameters: diff.clone(),
274                    });
275                    let tool_id = self.current_tool_id as usize;
276                    if tool_id < self.streamed_args_for_tool.len() {
277                        self.streamed_args_for_tool[tool_id].push_str(&diff);
278                    }
279                }
280
281                // Update current state
282                self.current_parameters.clone_from(&new_params);
283                let tool_id = self.current_tool_id as usize;
284                if tool_id < self.prev_tool_call_arr.len() {
285                    if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
286                        obj.insert("arguments".to_string(), Value::Object(new_params));
287                    }
288                }
289            }
290
291            // Check if tool call is complete
292            if self.buffer.contains(self.tool_call_end) {
293                // Send closing brace if we've sent any parameters
294                let tool_id = self.current_tool_id as usize;
295                if tool_id < self.streamed_args_for_tool.len()
296                    && !self.streamed_args_for_tool[tool_id].is_empty()
297                {
298                    calls.push(ToolCallItem {
299                        tool_index: self.current_tool_id as usize,
300                        name: None,
301                        parameters: "}".to_string(),
302                    });
303                    self.streamed_args_for_tool[tool_id].push('}');
304                }
305
306                // Find the end position
307                if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
308                    // Remove the processed tool call from buffer
309                    self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
310                }
311
312                // Reset state for next tool call
313                self.reset_streaming_state();
314                self.current_tool_id += 1;
315            }
316        }
317
318        StreamingParseResult {
319            normal_text: String::new(),
320            calls,
321        }
322    }
323
324    /// Parse parameters from steptml format
325    fn parse_steptml_parameters(&self, params_text: &str) -> serde_json::Map<String, Value> {
326        let mut parameters = serde_json::Map::new();
327
328        for capture in self.param_extractor.captures_iter(params_text) {
329            let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
330            let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
331
332            // Try to parse the value as JSON first, fallback to string
333            let param_value = if let Ok(json_val) = serde_json::from_str::<Value>(param_value_str) {
334                json_val
335            } else {
336                // Try parsing as Python literal
337                if param_value_str == "true" || param_value_str == "True" {
338                    Value::Bool(true)
339                } else if param_value_str == "false" || param_value_str == "False" {
340                    Value::Bool(false)
341                } else if param_value_str == "null" || param_value_str == "None" {
342                    Value::Null
343                } else if let Ok(num) = param_value_str.parse::<i64>() {
344                    Value::Number(num.into())
345                } else if let Ok(num) = param_value_str.parse::<f64>() {
346                    if let Some(n) = serde_json::Number::from_f64(num) {
347                        Value::Number(n)
348                    } else {
349                        Value::String(param_value_str.to_string())
350                    }
351                } else {
352                    Value::String(param_value_str.to_string())
353                }
354            };
355
356            parameters.insert(param_name.to_string(), param_value);
357        }
358
359        parameters
360    }
361
362    /// Parse a single tool call block
363    fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
364        // Check if it contains function marker and tool separator
365        if !block.contains("function") || !block.contains("<|tool_sep|>") {
366            return Ok(None);
367        }
368
369        // Split by tool separator
370        let parts: Vec<&str> = block.split("<|tool_sep|>").collect();
371        if parts.len() != 2 {
372            return Ok(None);
373        }
374
375        // Check if it's a function type
376        if !parts[0].contains("function") {
377            return Ok(None);
378        }
379
380        let invoke_part = parts[1];
381
382        // Extract steptml invoke
383        if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
384            let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
385
386            // Validate function name is not empty
387            if func_name.is_empty() {
388                return Ok(None);
389            }
390
391            let params_text = captures.get(2).map_or("", |m| m.as_str());
392
393            // Parse parameters
394            let parameters = self.parse_steptml_parameters(params_text);
395
396            let arguments_str = serde_json::to_string(&parameters)
397                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
398
399            Ok(Some(ToolCall {
400                function: FunctionCall {
401                    name: func_name.to_string(),
402                    arguments: arguments_str,
403                },
404            }))
405        } else {
406            Ok(None)
407        }
408    }
409}
410
411impl Default for Step3Parser {
412    fn default() -> Self {
413        Self::new()
414    }
415}
416
417#[async_trait]
418impl ToolParser for Step3Parser {
419    async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
420        if !self.has_tool_markers(text) {
421            return Ok((text.to_string(), vec![]));
422        }
423
424        // Find where tool calls begin
425        // Safe: has_tool_markers() already confirmed the marker exists
426        let idx = text
427            .find("<|tool_calls_begin|>")
428            .ok_or_else(|| ParserError::ParsingFailed("tool call marker not found".to_string()))?;
429        let normal_text = text[..idx].to_string();
430
431        // Extract tool calls
432        let mut tools = Vec::new();
433        for mat in self.tool_call_extractor.find_iter(text) {
434            match self.parse_tool_call(mat.as_str()) {
435                Ok(Some(tool)) => tools.push(tool),
436                Ok(None) => continue,
437                Err(e) => {
438                    tracing::debug!("Failed to parse tool call: {}", e);
439                    continue;
440                }
441            }
442        }
443
444        // If no tools were successfully parsed despite having markers, return entire text as fallback
445        if tools.is_empty() {
446            return Ok((text.to_string(), vec![]));
447        }
448
449        Ok((normal_text, tools))
450    }
451
452    async fn parse_incremental(
453        &mut self,
454        chunk: &str,
455        tools: &[Tool],
456    ) -> ParserResult<StreamingParseResult> {
457        self.buffer.push_str(chunk);
458
459        // Build tool indices for validation
460        let tool_indices = helpers::get_tool_indices(tools);
461
462        // Stage 1: If we've finished the tool block, everything is normal text
463        if self.tool_block_finished {
464            let normal_text = std::mem::take(&mut self.buffer);
465            return Ok(StreamingParseResult {
466                normal_text,
467                calls: vec![],
468            });
469        }
470
471        // Stage 2: Check if tool block hasn't started yet
472        if !self.in_tool_block {
473            if self.buffer.contains(self.bot_token) {
474                // Safe: contains() confirmed the token exists
475                let idx = self.buffer.find(self.bot_token).ok_or_else(|| {
476                    ParserError::ParsingFailed("token not found in buffer".to_string())
477                })?;
478                let normal_text = self.buffer[..idx].to_string();
479                self.buffer = self.buffer[idx + self.bot_token.len()..].to_string();
480                self.in_tool_block = true;
481                return Ok(StreamingParseResult {
482                    normal_text,
483                    calls: vec![],
484                });
485            } else {
486                // Check if we might have a partial bot_token
487                if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_some() {
488                    return Ok(StreamingParseResult::default()); // Wait for more text
489                } else {
490                    let normal_text = std::mem::take(&mut self.buffer);
491                    return Ok(StreamingParseResult {
492                        normal_text,
493                        calls: vec![],
494                    });
495                }
496            }
497        }
498
499        // We're inside the tool block
500        let mut calls = Vec::new();
501
502        // Stage 3: Check if tool block is ending
503        if self.buffer.contains(self.eot_token) {
504            // Safe: contains() confirmed the token exists
505            let idx = self.buffer.find(self.eot_token).ok_or_else(|| {
506                ParserError::ParsingFailed("token not found in buffer".to_string())
507            })?;
508
509            // If we're in the middle of a tool call, we need to handle it
510            if self.in_tool_call {
511                // The buffer before eot_token might contain the end of the current tool call
512                let before_eot = &self.buffer[..idx];
513                if before_eot.contains(self.tool_call_end) {
514                    // Parse this final tool call
515                    let result = self.parse_partial_tool_call(&tool_indices);
516                    calls.extend(result.calls);
517                } else {
518                    // Incomplete tool call - log warning
519                    tracing::warn!("Tool block ended with incomplete tool call");
520                }
521            }
522
523            let remaining = self.buffer[idx + self.eot_token.len()..].to_string();
524            self.buffer.clear();
525            self.tool_block_finished = true;
526
527            // Reset any partial tool call state
528            self.reset_streaming_state();
529
530            return Ok(StreamingParseResult {
531                normal_text: remaining,
532                calls,
533            });
534        }
535
536        // Stage 4: Check if we're in a tool call or need to start one
537        if !self.in_tool_call {
538            if self.buffer.contains(self.tool_call_begin) {
539                // Safe: contains() confirmed the token exists
540                let idx = self.buffer.find(self.tool_call_begin).ok_or_else(|| {
541                    ParserError::ParsingFailed("token not found in buffer".to_string())
542                })?;
543                // Remove any content before tool call begin (shouldn't happen but be safe)
544                self.buffer = self.buffer[idx + self.tool_call_begin.len()..].to_string();
545                self.in_tool_call = true;
546                self.function_name_sent = false;
547                self.current_function_name.clear();
548                self.current_parameters.clear();
549                // Fall through to parse the partial tool call
550            } else {
551                // Wait for tool call to begin
552                return Ok(StreamingParseResult::default());
553            }
554        }
555
556        // Stage 5: Parse partial tool call
557        if self.in_tool_call {
558            return Ok(self.parse_partial_tool_call(&tool_indices));
559        }
560
561        Ok(StreamingParseResult::default())
562    }
563
564    fn has_tool_markers(&self, text: &str) -> bool {
565        text.contains(self.bot_token)
566    }
567
568    fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
569        helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
570    }
571
572    fn reset(&mut self) {
573        // Reset standard state
574        self.buffer.clear();
575        self.prev_tool_call_arr.clear();
576        self.current_tool_id = -1;
577        self.streamed_args_for_tool.clear();
578
579        // Reset Step3-specific fields
580        self.in_tool_block = false;
581        self.tool_block_finished = false;
582        self.current_function_name.clear();
583        self.current_parameters.clear();
584        self.in_tool_call = false;
585        self.function_name_sent = false;
586    }
587}