Skip to main content

tool_parser/parsers/
helpers.rs

1use std::collections::HashMap;
2
3use openai_protocol::common::Tool;
4use serde_json::Value;
5
6use crate::{
7    errors::{ParserError, ParserResult},
8    types::{StreamingParseResult, ToolCallItem},
9};
10
11/// Get a mapping of tool names to their indices
12pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
13    tools
14        .iter()
15        .enumerate()
16        .map(|(i, tool)| (tool.function.name.clone(), i))
17        .collect()
18}
19
/// Return the longest leading run of characters shared by `s1` and `s2`.
///
/// Comparison is done character by character (not byte by byte), so the result
/// is always valid UTF-8. Used for incremental argument streaming, where two
/// successive partial-JSON serializations may only agree on a leading portion.
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
    let mut shared = String::new();
    let mut right = s2.chars();

    for left in s1.chars() {
        match right.next() {
            Some(other) if other == left => shared.push(left),
            // Either `s2` ran out or the characters diverged
            _ => break,
        }
    }

    shared
}
29
30/// Get unstreamed tool call arguments
31/// Returns tool call items for arguments that have been parsed but not yet streamed
32/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
33pub fn get_unstreamed_args(
34    prev_tool_call_arr: &[Value],
35    streamed_args_for_tool: &[String],
36) -> Option<Vec<ToolCallItem>> {
37    // Check if we have tool calls being tracked
38    if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
39        return None;
40    }
41
42    // Get the last tool call that was being processed
43    let tool_index = prev_tool_call_arr.len() - 1;
44    if tool_index >= streamed_args_for_tool.len() {
45        return None;
46    }
47
48    // Get expected vs actual arguments
49    let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
50    let expected_str = serde_json::to_string(expected_args).ok()?;
51    let actual_str = &streamed_args_for_tool[tool_index];
52
53    // Check if there are remaining arguments to send
54    let remaining = if expected_str.starts_with(actual_str) {
55        &expected_str[actual_str.len()..]
56    } else {
57        return None;
58    };
59
60    if remaining.is_empty() {
61        return None;
62    }
63
64    // Return the remaining arguments as a ToolCallItem
65    Some(vec![ToolCallItem {
66        tool_index,
67        name: None, // No name for argument deltas
68        parameters: remaining.to_string(),
69    }])
70}
71
/// Check whether `buffer` ends with a partial occurrence of `token`.
///
/// A partial occurrence is a non-empty, proper prefix of `token` (the full token
/// itself is not tried). Prefixes are tried from shortest to longest and the first
/// one that `buffer` ends with wins.
///
/// Only prefixes that end on a UTF-8 character boundary are considered, so tokens
/// containing multi-byte characters are handled without slicing through a character.
///
/// # Returns
/// - `Some(len)` where `len` is the byte length of the matching prefix of `token`.
/// - `None` if either input is empty or no proper prefix matches.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }

    token
        .char_indices()
        // Offset 0 would be the empty prefix; every later offset is a valid
        // boundary strictly inside the token.
        .skip(1)
        .map(|(boundary, _)| boundary)
        .find(|&boundary| buffer.ends_with(&token[..boundary]))
}
81
82/// Reset state for the current tool being parsed (used when skipping invalid tools).
83/// This preserves the parser's overall state (current_tool_id, prev_tool_call_arr)
84/// but clears the state specific to the current incomplete tool.
85pub fn reset_current_tool_state(
86    buffer: &mut String,
87    current_tool_name_sent: &mut bool,
88    streamed_args_for_tool: &mut Vec<String>,
89    prev_tool_call_arr: &[Value],
90) {
91    buffer.clear();
92    *current_tool_name_sent = false;
93
94    // Only pop if we added an entry for the current (invalid) tool
95    // streamed_args_for_tool should match prev_tool_call_arr length for completed tools
96    if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
97        streamed_args_for_tool.pop();
98    }
99}
100
101/// Reset the entire parser state (used at the start of a new request).
102/// Clears all accumulated tool calls and resets all state to initial values.
103pub fn reset_parser_state(
104    buffer: &mut String,
105    prev_tool_call_arr: &mut Vec<Value>,
106    current_tool_id: &mut i32,
107    current_tool_name_sent: &mut bool,
108    streamed_args_for_tool: &mut Vec<String>,
109) {
110    buffer.clear();
111    prev_tool_call_arr.clear();
112    *current_tool_id = -1;
113    *current_tool_name_sent = false;
114    streamed_args_for_tool.clear();
115}
116
117/// Ensure arrays have capacity for the given tool ID
118pub fn ensure_capacity(
119    current_tool_id: i32,
120    prev_tool_call_arr: &mut Vec<Value>,
121    streamed_args_for_tool: &mut Vec<String>,
122) {
123    if current_tool_id < 0 {
124        return;
125    }
126    let needed = (current_tool_id + 1) as usize;
127
128    if prev_tool_call_arr.len() < needed {
129        prev_tool_call_arr.resize_with(needed, || Value::Null);
130    }
131    if streamed_args_for_tool.len() < needed {
132        streamed_args_for_tool.resize_with(needed, String::new);
133    }
134}
135
136/// Check if a string contains complete, valid JSON
137pub fn is_complete_json(input: &str) -> bool {
138    serde_json::from_str::<Value>(input).is_ok()
139}
140
141/// Normalize the arguments/parameters field in a tool call object.
142/// If the object has "parameters" but not "arguments", copy parameters to arguments.
143///
144/// # Background
145/// Different LLM formats use different field names:
146/// - Llama and JSON parsers use "parameters" (correct per JSON Schema spec)
147/// - Mistral and Qwen use "arguments"
148///
149/// This function normalizes to "arguments" for consistent downstream processing.
150pub fn normalize_arguments_field(mut obj: Value) -> Value {
151    if obj.get("arguments").is_none() {
152        if let Some(params) = obj.get("parameters").cloned() {
153            if let Value::Object(ref mut map) = obj {
154                map.insert("arguments".to_string(), params);
155            }
156        }
157    }
158    obj
159}
160
161/// Normalize the name/tool_name field in a tool call object.
162/// If the object has "tool_name" but not "name", copy tool_name to name.
163///
164/// # Background
165/// Cohere models use "tool_name" instead of "name":
166/// - Standard format uses "name"
167/// - Cohere uses "tool_name"
168///
169/// This function normalizes to "name" for consistent downstream processing.
170pub fn normalize_name_field(mut obj: Value) -> Value {
171    if obj.get("name").is_none() {
172        if let Some(tool_name) = obj.get("tool_name").cloned() {
173            if let Value::Object(ref mut map) = obj {
174                map.insert("name".to_string(), tool_name);
175            }
176        }
177    }
178    obj
179}
180
181/// Normalize all tool call fields (both name and arguments).
182/// Combines normalize_name_field and normalize_arguments_field.
183///
184/// This handles formats like Cohere that use both "tool_name" and "parameters"
185/// instead of the standard "name" and "arguments".
186pub fn normalize_tool_call_fields(obj: Value) -> Value {
187    let obj = normalize_name_field(obj);
188    normalize_arguments_field(obj)
189}
190
191/// Handle the entire JSON tool call streaming process for JSON-based parsers.
192///
193/// This unified function handles all aspects of streaming tool calls:
194/// - Parsing partial JSON from the buffer
195/// - Validating tool names against available tools
196/// - Streaming tool names (Case 1)
197/// - Streaming tool arguments (Case 2)
198/// - Managing parser state and buffer updates
199///
200/// Used by JSON, Llama, Mistral, and Qwen parsers.
201///
202/// # Parameters
203/// - `current_text`: The current buffered text being parsed
204/// - `start_idx`: Start index of JSON content in current_text
205/// - `partial_json`: Mutable reference to partial JSON parser
206/// - `tool_indices`: Map of valid tool names to their indices
207/// - `buffer`: Mutable parser buffer
208/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
209/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
210/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
211/// - `prev_tool_call_arr`: Mutable array of previous tool call states
212///
213/// # Returns
214/// - `Ok(StreamingParseResult)` with any tool call items to stream
215/// - `Err(ParserError)` if JSON parsing or serialization fails
216#[expect(clippy::too_many_arguments)]
217pub(crate) fn handle_json_tool_streaming(
218    current_text: &str,
219    start_idx: usize,
220    partial_json: &mut crate::partial_json::PartialJson,
221    tool_indices: &HashMap<String, usize>,
222    buffer: &mut String,
223    current_tool_id: &mut i32,
224    current_tool_name_sent: &mut bool,
225    streamed_args_for_tool: &mut Vec<String>,
226    prev_tool_call_arr: &mut Vec<Value>,
227) -> ParserResult<StreamingParseResult> {
228    // Check if we have content to parse
229    if start_idx >= current_text.len() {
230        return Ok(StreamingParseResult::default());
231    }
232
233    // Extract JSON string from current position
234    let json_str = &current_text[start_idx..];
235
236    // When current_tool_name_sent is false, don't allow partial strings to avoid
237    // parsing incomplete tool names as empty strings
238    let allow_partial_strings = *current_tool_name_sent;
239
240    // Parse partial JSON
241    let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
242        Ok(result) => result,
243        Err(_) => {
244            return Ok(StreamingParseResult::default());
245        }
246    };
247
248    // Check if JSON is complete - validate only the parsed portion
249    // Ensure end_idx is on a valid UTF-8 character boundary
250    let safe_end_idx = if json_str.is_char_boundary(end_idx) {
251        end_idx
252    } else {
253        // Find the nearest valid character boundary before end_idx
254        (0..end_idx)
255            .rev()
256            .find(|&i| json_str.is_char_boundary(i))
257            .unwrap_or(0)
258    };
259    let is_complete = serde_json::from_str::<Value>(&json_str[..safe_end_idx]).is_ok();
260
261    // Normalize all tool call fields first (handles tool_name -> name, parameters -> arguments)
262    // This must happen before validation since different LLMs use different field names
263    let current_tool_call = normalize_tool_call_fields(obj);
264
265    // Validate tool name if present
266    if let Some(name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
267        if !tool_indices.contains_key(name) {
268            // Invalid tool name - skip this tool, preserve indexing for next tool
269            tracing::debug!("Invalid tool name '{}' - skipping", name);
270            reset_current_tool_state(
271                buffer,
272                current_tool_name_sent,
273                streamed_args_for_tool,
274                prev_tool_call_arr,
275            );
276            return Ok(StreamingParseResult::default());
277        }
278    }
279
280    let mut result = StreamingParseResult::default();
281
282    // Case 1: Handle tool name streaming
283    if !*current_tool_name_sent {
284        if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
285            if tool_indices.contains_key(function_name) {
286                // Initialize if first tool
287                if *current_tool_id == -1 {
288                    *current_tool_id = 0;
289                    streamed_args_for_tool.push(String::new());
290                } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
291                    // Ensure capacity for subsequent tools
292                    ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
293                }
294
295                // Send tool name with empty parameters
296                *current_tool_name_sent = true;
297                result.calls.push(ToolCallItem {
298                    tool_index: *current_tool_id as usize,
299                    name: Some(function_name.to_string()),
300                    parameters: String::new(),
301                });
302            }
303        }
304    }
305    // Case 2: Handle streaming arguments
306    else if let Some(cur_arguments) = current_tool_call.get("arguments") {
307        let tool_id = *current_tool_id as usize;
308        let sent = streamed_args_for_tool
309            .get(tool_id)
310            .map(|s| s.len())
311            .unwrap_or(0);
312        let cur_args_json = serde_json::to_string(cur_arguments)
313            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
314
315        // Get prev_arguments (matches Python's structure)
316        let prev_arguments = if tool_id < prev_tool_call_arr.len() {
317            prev_tool_call_arr[tool_id].get("arguments")
318        } else {
319            None
320        };
321
322        // Calculate diff: everything after we've already sent
323        let mut argument_diff = None;
324
325        if is_complete {
326            // Python: argument_diff = cur_args_json[sent:]
327            // Rust needs bounds check (Python returns "" automatically)
328            argument_diff = if sent < cur_args_json.len() {
329                Some(cur_args_json[sent..].to_string())
330            } else {
331                Some(String::new())
332            };
333        } else if let Some(prev_args) = prev_arguments {
334            let prev_args_json = serde_json::to_string(prev_args)
335                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
336
337            if cur_args_json != prev_args_json {
338                let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
339                argument_diff = if sent < prefix.len() {
340                    Some(prefix[sent..].to_string())
341                } else {
342                    Some(String::new())
343                };
344            }
345        }
346
347        // Send diff if present
348        if let Some(diff) = argument_diff {
349            if !diff.is_empty() {
350                if tool_id < streamed_args_for_tool.len() {
351                    streamed_args_for_tool[tool_id].push_str(&diff);
352                }
353                result.calls.push(ToolCallItem {
354                    tool_index: tool_id,
355                    name: None,
356                    parameters: diff,
357                });
358            }
359        }
360
361        // Update prev_tool_call_arr with current state
362        if *current_tool_id >= 0 {
363            ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
364
365            if tool_id < prev_tool_call_arr.len() {
366                prev_tool_call_arr[tool_id] = current_tool_call;
367            }
368        }
369
370        // If complete, advance to next tool
371        if is_complete {
372            *buffer = current_text[start_idx + end_idx..].to_string();
373            *current_tool_name_sent = false;
374            *current_tool_id += 1;
375        }
376    }
377
378    Ok(result)
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Token used by the partial-token checks below.
    const PYTHON_TAG: &str = "<|python_tag|>";

    #[test]
    fn test_ends_with_partial_token() {
        // Buffers that stop part-way through the token
        for partial in ["hello <|py", "hello <|python_tag"] {
            assert!(ends_with_partial_token(partial, PYTHON_TAG).is_some());
        }

        // Full token, empty buffer, and unrelated text
        for other in ["hello <|python_tag|>", "", "hello world"] {
            assert!(ends_with_partial_token(other, PYTHON_TAG).is_none());
        }
    }

    #[test]
    fn test_reset_current_tool_state() {
        let mut buf = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args"), String::from("tool1_partial")];
        let completed = [serde_json::json!({"name": "tool0"})];

        reset_current_tool_state(&mut buf, &mut name_sent, &mut streamed, &completed);

        assert_eq!(buf, "");
        assert!(!name_sent);
        // The partial tool1 entry is gone, tool0 is untouched
        assert_eq!(streamed.len(), 1);
        assert_eq!(streamed[0], "tool0_args");
    }

    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let mut buf = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args")];
        let completed = [serde_json::json!({"name": "tool0"})];

        reset_current_tool_state(&mut buf, &mut name_sent, &mut streamed, &completed);

        assert_eq!(buf, "");
        assert!(!name_sent);
        // Lengths already matched, so nothing is removed
        assert_eq!(streamed.len(), 1);
    }

    #[test]
    fn test_reset_parser_state() {
        let mut buf = "some buffer".to_string();
        let mut tracked = vec![serde_json::json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut streamed = vec![String::from("args")];

        reset_parser_state(
            &mut buf,
            &mut tracked,
            &mut tool_id,
            &mut name_sent,
            &mut streamed,
        );

        assert_eq!(buf, "");
        assert_eq!(tracked.len(), 0);
        assert_eq!(tool_id, -1);
        assert!(!name_sent);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_ensure_capacity() {
        let mut tracked: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(2, &mut tracked, &mut streamed);

        assert_eq!(tracked.len(), 3);
        assert_eq!(streamed.len(), 3);
        assert_eq!(tracked[0], Value::Null);
        assert_eq!(streamed[0], "");
    }

    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut tracked: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(-1, &mut tracked, &mut streamed);

        // A negative id must leave both vectors alone
        assert_eq!(tracked.len(), 0);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_is_complete_json() {
        for complete in [r#"{"name": "test"}"#, "[1, 2, 3]", "42", "true"] {
            assert!(is_complete_json(complete));
        }
        for truncated in [r#"{"name": "#, "[1, 2,"] {
            assert!(!is_complete_json(truncated));
        }
    }

    #[test]
    fn test_normalize_arguments_field() {
        // Case 1: Has parameters, no arguments
        let only_params = serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let out = normalize_arguments_field(only_params);
        assert_eq!(
            out.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );

        // Case 2: Already has arguments
        let with_args = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_arguments_field(with_args.clone()), with_args);

        // Case 3: No parameters or arguments
        let bare = serde_json::json!({"name": "test"});
        assert_eq!(normalize_arguments_field(bare.clone()), bare);
    }

    #[test]
    fn test_normalize_name_field() {
        // Case 1: Has tool_name, no name (Cohere format)
        let cohere = serde_json::json!({
            "tool_name": "search",
            "parameters": {"query": "test"}
        });
        let out = normalize_name_field(cohere);
        assert_eq!(out.get("name").unwrap(), "search");

        // Case 2: Already has name (standard format)
        let standard = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_name_field(standard.clone()), standard);

        // Case 3: Has both tool_name and name - name takes precedence
        let both = serde_json::json!({
            "tool_name": "cohere_name",
            "name": "standard_name",
            "parameters": {}
        });
        let out = normalize_name_field(both);
        assert_eq!(out.get("name").unwrap(), "standard_name");

        // Case 4: No name or tool_name
        let nameless = serde_json::json!({"parameters": {}});
        let out = normalize_name_field(nameless);
        assert!(out.get("name").is_none());
    }

    #[test]
    fn test_normalize_tool_call_fields() {
        // Case 1: Full Cohere format with tool_name and parameters
        let cohere = serde_json::json!({
            "tool_name": "search",
            "parameters": {"query": "rust programming"}
        });
        let out = normalize_tool_call_fields(cohere);
        assert_eq!(out.get("name").unwrap(), "search");
        assert_eq!(
            out.get("arguments").unwrap(),
            &serde_json::json!({"query": "rust programming"})
        );

        // Case 2: Standard format - should remain unchanged
        let standard = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_tool_call_fields(standard.clone()), standard);

        // Case 3: Mixed format (name + parameters)
        let mixed = serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let out = normalize_tool_call_fields(mixed);
        assert_eq!(out.get("name").unwrap(), "test");
        assert_eq!(
            out.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );
    }
}
585}