Skip to main content

tool_parser/parsers/
helpers.rs

1use std::collections::HashMap;
2
3use openai_protocol::common::Tool;
4use serde::de::{Deserialize, IgnoredAny};
5use serde_json::{de::Deserializer, Value};
6
7use crate::{
8    errors::{ParserError, ParserResult},
9    types::{StreamingParseResult, ToolCallItem},
10};
11
12/// Get a mapping of tool names to their indices
13pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
14    tools
15        .iter()
16        .enumerate()
17        .map(|(i, tool)| (tool.function.name.clone(), i))
18        .collect()
19}
20
/// Return the longest leading run of characters shared by `s1` and `s2`.
///
/// The comparison is done character by character (not byte by byte), so the
/// result is always valid UTF-8. Used for incremental argument streaming when
/// partial JSON parsing yields different intermediate states.
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
    let mut shared = String::new();
    let mut right = s2.chars();
    for left in s1.chars() {
        match right.next() {
            Some(other) if other == left => shared.push(left),
            // Either `s2` ran out or the characters diverged.
            _ => break,
        }
    }
    shared
}
30
31/// Get unstreamed tool call arguments
32/// Returns tool call items for arguments that have been parsed but not yet streamed
33/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
34pub fn get_unstreamed_args(
35    prev_tool_call_arr: &[Value],
36    streamed_args_for_tool: &[String],
37) -> Option<Vec<ToolCallItem>> {
38    // Check if we have tool calls being tracked
39    if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
40        return None;
41    }
42
43    // Get the last tool call that was being processed
44    let tool_index = prev_tool_call_arr.len() - 1;
45    if tool_index >= streamed_args_for_tool.len() {
46        return None;
47    }
48
49    // Get expected vs actual arguments
50    let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
51    let expected_str = serde_json::to_string(expected_args).ok()?;
52    let actual_str = &streamed_args_for_tool[tool_index];
53
54    // Check if there are remaining arguments to send
55    let remaining = if expected_str.starts_with(actual_str) {
56        &expected_str[actual_str.len()..]
57    } else {
58        return None;
59    };
60
61    if remaining.is_empty() {
62        return None;
63    }
64
65    // Return the remaining arguments as a ToolCallItem
66    Some(vec![ToolCallItem {
67        tool_index,
68        name: None, // No name for argument deltas
69        parameters: remaining.to_string(),
70    }])
71}
72
/// Check whether `buffer` ends with a partial (strictly shorter) occurrence of `token`.
///
/// Returns `Some(len)` where `len` is the byte length of the shortest proper,
/// non-empty prefix of `token` that `buffer` ends with, or `None` when there is
/// no such prefix, or when either input is empty. A buffer that ends with the
/// complete token is not a partial match by itself.
///
/// Only prefixes that end on a UTF-8 character boundary of `token` are tried,
/// so tokens containing multi-byte characters are handled without panicking.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }

    (1..token.len())
        // Slicing `token` at a non-boundary offset would panic; skip those offsets.
        .filter(|&len| token.is_char_boundary(len))
        .find(|&len| buffer.ends_with(&token[..len]))
}
82
83/// Reset state for the current tool being parsed (used when skipping invalid tools).
84/// This preserves the parser's overall state (current_tool_id, prev_tool_call_arr)
85/// but clears the state specific to the current incomplete tool.
86pub fn reset_current_tool_state(
87    buffer: &mut String,
88    current_tool_name_sent: &mut bool,
89    streamed_args_for_tool: &mut Vec<String>,
90    prev_tool_call_arr: &[Value],
91) {
92    buffer.clear();
93    *current_tool_name_sent = false;
94
95    // Only pop if we added an entry for the current (invalid) tool
96    // streamed_args_for_tool should match prev_tool_call_arr length for completed tools
97    if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
98        streamed_args_for_tool.pop();
99    }
100}
101
102/// Reset the entire parser state (used at the start of a new request).
103/// Clears all accumulated tool calls and resets all state to initial values.
104pub fn reset_parser_state(
105    buffer: &mut String,
106    prev_tool_call_arr: &mut Vec<Value>,
107    current_tool_id: &mut i32,
108    current_tool_name_sent: &mut bool,
109    streamed_args_for_tool: &mut Vec<String>,
110) {
111    buffer.clear();
112    prev_tool_call_arr.clear();
113    *current_tool_id = -1;
114    *current_tool_name_sent = false;
115    streamed_args_for_tool.clear();
116}
117
118/// Ensure arrays have capacity for the given tool ID
119pub fn ensure_capacity(
120    current_tool_id: i32,
121    prev_tool_call_arr: &mut Vec<Value>,
122    streamed_args_for_tool: &mut Vec<String>,
123) {
124    if current_tool_id < 0 {
125        return;
126    }
127    let needed = (current_tool_id + 1) as usize;
128
129    if prev_tool_call_arr.len() < needed {
130        prev_tool_call_arr.resize_with(needed, || Value::Null);
131    }
132    if streamed_args_for_tool.len() < needed {
133        streamed_args_for_tool.resize_with(needed, String::new);
134    }
135}
136
137/// Check if a string contains complete, valid JSON
138pub fn is_complete_json(input: &str) -> bool {
139    let mut de = Deserializer::from_str(input);
140    IgnoredAny::deserialize(&mut de).is_ok() && de.end().is_ok()
141}
142
143/// Normalize the arguments/parameters field in a tool call object.
144/// If the object has "parameters" but not "arguments", copy parameters to arguments.
145///
146/// # Background
147/// Different LLM formats use different field names:
148/// - Llama and JSON parsers use "parameters" (correct per JSON Schema spec)
149/// - Mistral and Qwen use "arguments"
150///
151/// This function normalizes to "arguments" for consistent downstream processing.
152pub fn normalize_arguments_field(mut obj: Value) -> Value {
153    if obj.get("arguments").is_none() {
154        if let Some(params) = obj.get("parameters").cloned() {
155            if let Value::Object(ref mut map) = obj {
156                map.insert("arguments".to_string(), params);
157            }
158        }
159    }
160    obj
161}
162
163/// Normalize the name/tool_name field in a tool call object.
164/// If the object has "tool_name" but not "name", copy tool_name to name.
165///
166/// # Background
167/// Cohere models use "tool_name" instead of "name":
168/// - Standard format uses "name"
169/// - Cohere uses "tool_name"
170///
171/// This function normalizes to "name" for consistent downstream processing.
172pub fn normalize_name_field(mut obj: Value) -> Value {
173    if obj.get("name").is_none() {
174        if let Some(tool_name) = obj.get("tool_name").cloned() {
175            if let Value::Object(ref mut map) = obj {
176                map.insert("name".to_string(), tool_name);
177            }
178        }
179    }
180    obj
181}
182
183/// Normalize all tool call fields (both name and arguments).
184/// Combines normalize_name_field and normalize_arguments_field.
185///
186/// This handles formats like Cohere that use both "tool_name" and "parameters"
187/// instead of the standard "name" and "arguments".
188pub fn normalize_tool_call_fields(obj: Value) -> Value {
189    let obj = normalize_name_field(obj);
190    normalize_arguments_field(obj)
191}
192
193/// Handle the entire JSON tool call streaming process for JSON-based parsers.
194///
195/// This unified function handles all aspects of streaming tool calls:
196/// - Parsing partial JSON from the buffer
197/// - Validating tool names against available tools
198/// - Streaming tool names (Case 1)
199/// - Streaming tool arguments (Case 2)
200/// - Managing parser state and buffer updates
201///
202/// Used by JSON, Llama, Mistral, and Qwen parsers.
203///
204/// # Parameters
205/// - `current_text`: The current buffered text being parsed
206/// - `start_idx`: Start index of JSON content in current_text
207/// - `partial_json`: Mutable reference to partial JSON parser
208/// - `tool_indices`: Map of valid tool names to their indices
209/// - `buffer`: Mutable parser buffer
210/// - `current_tool_id`: Mutable current tool index (-1 means no active tool)
211/// - `current_tool_name_sent`: Mutable flag for whether current tool's name was sent
212/// - `streamed_args_for_tool`: Mutable accumulator of streamed arguments per tool
213/// - `prev_tool_call_arr`: Mutable array of previous tool call states
214///
215/// # Returns
216/// - `Ok(StreamingParseResult)` with any tool call items to stream
217/// - `Err(ParserError)` if JSON parsing or serialization fails
218#[expect(clippy::too_many_arguments)]
219pub(crate) fn handle_json_tool_streaming(
220    current_text: &str,
221    start_idx: usize,
222    partial_json: &mut crate::partial_json::PartialJson,
223    tool_indices: &HashMap<String, usize>,
224    buffer: &mut String,
225    current_tool_id: &mut i32,
226    current_tool_name_sent: &mut bool,
227    streamed_args_for_tool: &mut Vec<String>,
228    prev_tool_call_arr: &mut Vec<Value>,
229) -> ParserResult<StreamingParseResult> {
230    // Check if we have content to parse
231    if start_idx >= current_text.len() {
232        return Ok(StreamingParseResult::default());
233    }
234
235    // Extract JSON string from current position
236    let json_str = &current_text[start_idx..];
237
238    // When current_tool_name_sent is false, don't allow partial strings to avoid
239    // parsing incomplete tool names as empty strings
240    let allow_partial_strings = *current_tool_name_sent;
241
242    // Parse partial JSON
243    let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
244        Ok(result) => result,
245        Err(_) => {
246            return Ok(StreamingParseResult::default());
247        }
248    };
249
250    // Check if JSON is complete - validate only the parsed portion
251    // Ensure end_idx is on a valid UTF-8 character boundary
252    let safe_end_idx = if json_str.is_char_boundary(end_idx) {
253        end_idx
254    } else {
255        // Find the nearest valid character boundary before end_idx
256        (0..end_idx)
257            .rev()
258            .find(|&i| json_str.is_char_boundary(i))
259            .unwrap_or(0)
260    };
261    let is_complete = is_complete_json(&json_str[..safe_end_idx]);
262
263    // Normalize all tool call fields first (handles tool_name -> name, parameters -> arguments)
264    // This must happen before validation since different LLMs use different field names
265    let current_tool_call = normalize_tool_call_fields(obj);
266
267    // Validate tool name if present
268    if let Some(name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
269        if !tool_indices.contains_key(name) {
270            // Invalid tool name - skip this tool, preserve indexing for next tool
271            tracing::debug!("Invalid tool name '{}' - skipping", name);
272            reset_current_tool_state(
273                buffer,
274                current_tool_name_sent,
275                streamed_args_for_tool,
276                prev_tool_call_arr,
277            );
278            return Ok(StreamingParseResult::default());
279        }
280    }
281
282    let mut result = StreamingParseResult::default();
283
284    // Case 1: Handle tool name streaming
285    if !*current_tool_name_sent {
286        if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
287            if tool_indices.contains_key(function_name) {
288                // Initialize if first tool
289                if *current_tool_id == -1 {
290                    *current_tool_id = 0;
291                    streamed_args_for_tool.push(String::new());
292                } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
293                    // Ensure capacity for subsequent tools
294                    ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
295                }
296
297                // Send tool name with empty parameters
298                *current_tool_name_sent = true;
299                result.calls.push(ToolCallItem {
300                    tool_index: *current_tool_id as usize,
301                    name: Some(function_name.to_string()),
302                    parameters: String::new(),
303                });
304            }
305        }
306    }
307    // Case 2: Handle streaming arguments
308    else if let Some(cur_arguments) = current_tool_call.get("arguments") {
309        let tool_id = *current_tool_id as usize;
310        let sent = streamed_args_for_tool
311            .get(tool_id)
312            .map(|s| s.len())
313            .unwrap_or(0);
314        let cur_args_json = serde_json::to_string(cur_arguments)
315            .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
316
317        // Get prev_arguments (matches Python's structure)
318        let prev_arguments = if tool_id < prev_tool_call_arr.len() {
319            prev_tool_call_arr[tool_id].get("arguments")
320        } else {
321            None
322        };
323
324        // Calculate diff: everything after we've already sent
325        let mut argument_diff = None;
326
327        if is_complete {
328            // Python: argument_diff = cur_args_json[sent:]
329            // Rust needs bounds check (Python returns "" automatically)
330            argument_diff = if sent < cur_args_json.len() {
331                Some(cur_args_json[sent..].to_string())
332            } else {
333                Some(String::new())
334            };
335        } else if let Some(prev_args) = prev_arguments {
336            let prev_args_json = serde_json::to_string(prev_args)
337                .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
338
339            if cur_args_json != prev_args_json {
340                let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
341                argument_diff = if sent < prefix.len() {
342                    Some(prefix[sent..].to_string())
343                } else {
344                    Some(String::new())
345                };
346            }
347        }
348
349        // Send diff if present
350        if let Some(diff) = argument_diff {
351            if !diff.is_empty() {
352                if tool_id < streamed_args_for_tool.len() {
353                    streamed_args_for_tool[tool_id].push_str(&diff);
354                }
355                result.calls.push(ToolCallItem {
356                    tool_index: tool_id,
357                    name: None,
358                    parameters: diff,
359                });
360            }
361        }
362
363        // Update prev_tool_call_arr with current state
364        if *current_tool_id >= 0 {
365            ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
366
367            if tool_id < prev_tool_call_arr.len() {
368                prev_tool_call_arr[tool_id] = current_tool_call;
369            }
370        }
371
372        // If complete, advance to next tool
373        if is_complete {
374            *buffer = current_text[start_idx + end_idx..].to_string();
375            *current_tool_name_sent = false;
376            *current_tool_id += 1;
377        }
378    }
379
380    Ok(result)
381}
382
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Token used by the partial-match checks.
    const PYTHON_TAG: &str = "<|python_tag|>";

    #[test]
    fn test_ends_with_partial_token() {
        // Buffers ending in a strict prefix of the token are partial matches.
        for buffer in ["hello <|py", "hello <|python_tag"] {
            assert!(ends_with_partial_token(buffer, PYTHON_TAG).is_some());
        }

        // A complete token, an empty buffer and unrelated text are not.
        for buffer in ["hello <|python_tag|>", "", "hello world"] {
            assert!(ends_with_partial_token(buffer, PYTHON_TAG).is_none());
        }
    }

    #[test]
    fn test_reset_current_tool_state() {
        let mut pending_text = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args"), String::from("tool1_partial")];
        let completed = vec![json!({"name": "tool0"})];

        reset_current_tool_state(&mut pending_text, &mut name_sent, &mut streamed, &completed);

        assert_eq!(pending_text, "");
        assert!(!name_sent);
        // The partial tool1 entry is dropped, tool0 stays.
        assert_eq!(streamed.len(), 1);
        assert_eq!(streamed[0], "tool0_args");
    }

    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let mut pending_text = "partial json".to_string();
        let mut name_sent = true;
        let mut streamed = vec![String::from("tool0_args")];
        let completed = vec![json!({"name": "tool0"})];

        reset_current_tool_state(&mut pending_text, &mut name_sent, &mut streamed, &completed);

        assert_eq!(pending_text, "");
        assert!(!name_sent);
        // Lengths already matched, so nothing is removed.
        assert_eq!(streamed.len(), 1);
    }

    #[test]
    fn test_reset_parser_state() {
        let mut pending_text = "some buffer".to_string();
        let mut completed = vec![json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut streamed = vec![String::from("args")];

        reset_parser_state(
            &mut pending_text,
            &mut completed,
            &mut tool_id,
            &mut name_sent,
            &mut streamed,
        );

        assert_eq!(pending_text, "");
        assert_eq!(completed.len(), 0);
        assert_eq!(tool_id, -1);
        assert!(!name_sent);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_ensure_capacity() {
        let mut completed: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(2, &mut completed, &mut streamed);

        assert_eq!(completed.len(), 3);
        assert_eq!(streamed.len(), 3);
        assert_eq!(completed[0], Value::Null);
        assert_eq!(streamed[0], "");
    }

    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut completed: Vec<Value> = Vec::new();
        let mut streamed: Vec<String> = Vec::new();

        ensure_capacity(-1, &mut completed, &mut streamed);

        // A negative ID must leave both vectors untouched.
        assert_eq!(completed.len(), 0);
        assert_eq!(streamed.len(), 0);
    }

    #[test]
    fn test_is_complete_json() {
        for complete in [r#"{"name": "test"}"#, "[1, 2, 3]", "42", "true"] {
            assert!(is_complete_json(complete));
        }
        for truncated in [r#"{"name": "#, "[1, 2,"] {
            assert!(!is_complete_json(truncated));
        }
    }

    #[test]
    fn test_normalize_arguments_field() {
        // Case 1: "parameters" only -> copied to "arguments"
        let with_parameters = json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_arguments_field(with_parameters);
        assert_eq!(
            normalized.get("arguments"),
            Some(&json!({"key": "value"}))
        );

        // Case 2: "arguments" already present -> unchanged
        let with_arguments = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(
            normalize_arguments_field(with_arguments.clone()),
            with_arguments
        );

        // Case 3: neither field present -> unchanged
        let name_only = json!({"name": "test"});
        assert_eq!(normalize_arguments_field(name_only.clone()), name_only);
    }

    #[test]
    fn test_normalize_name_field() {
        // Case 1: Cohere format, "tool_name" only
        let cohere = json!({
            "tool_name": "search",
            "parameters": {"query": "test"}
        });
        let normalized = normalize_name_field(cohere);
        assert_eq!(
            normalized.get("name").and_then(Value::as_str),
            Some("search")
        );

        // Case 2: standard format, "name" already present
        let standard = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_name_field(standard.clone()), standard);

        // Case 3: both present -> "name" takes precedence
        let both = json!({
            "tool_name": "cohere_name",
            "name": "standard_name",
            "parameters": {}
        });
        let normalized = normalize_name_field(both);
        assert_eq!(
            normalized.get("name").and_then(Value::as_str),
            Some("standard_name")
        );

        // Case 4: neither "name" nor "tool_name"
        let anonymous = json!({"parameters": {}});
        let normalized = normalize_name_field(anonymous);
        assert!(normalized.get("name").is_none());
    }

    #[test]
    fn test_normalize_tool_call_fields() {
        // Case 1: full Cohere format ("tool_name" + "parameters")
        let cohere = json!({
            "tool_name": "search",
            "parameters": {"query": "rust programming"}
        });
        let normalized = normalize_tool_call_fields(cohere);
        assert_eq!(
            normalized.get("name").and_then(Value::as_str),
            Some("search")
        );
        assert_eq!(
            normalized.get("arguments"),
            Some(&json!({"query": "rust programming"}))
        );

        // Case 2: standard format stays as it is
        let standard = json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_tool_call_fields(standard.clone()), standard);

        // Case 3: mixed format ("name" + "parameters")
        let mixed = json!({
            "name": "test",
            "parameters": {"key": "value"}
        });
        let normalized = normalize_tool_call_fields(mixed);
        assert_eq!(
            normalized.get("name").and_then(Value::as_str),
            Some("test")
        );
        assert_eq!(
            normalized.get("arguments"),
            Some(&json!({"key": "value"}))
        );
    }
}