1use std::collections::HashMap;
2
3use openai_protocol::common::Tool;
4use serde_json::Value;
5
6use crate::{
7 errors::{ParserError, ParserResult},
8 types::{StreamingParseResult, ToolCallItem},
9};
10
11pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
13 tools
14 .iter()
15 .enumerate()
16 .map(|(i, tool)| (tool.function.name.clone(), i))
17 .collect()
18}
19
/// Returns the longest leading run of characters that `s1` and `s2` share.
///
/// The comparison walks both strings character by character (not byte by
/// byte) and stops at the first mismatch or when either string ends, so the
/// result is always valid UTF-8 and may be empty.
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
    let mut shared = String::new();
    for (left, right) in s1.chars().zip(s2.chars()) {
        if left != right {
            break;
        }
        shared.push(left);
    }
    shared
}
29
30pub fn get_unstreamed_args(
34 prev_tool_call_arr: &[Value],
35 streamed_args_for_tool: &[String],
36) -> Option<Vec<ToolCallItem>> {
37 if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
39 return None;
40 }
41
42 let tool_index = prev_tool_call_arr.len() - 1;
44 if tool_index >= streamed_args_for_tool.len() {
45 return None;
46 }
47
48 let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
50 let expected_str = serde_json::to_string(expected_args).ok()?;
51 let actual_str = &streamed_args_for_tool[tool_index];
52
53 let remaining = if expected_str.starts_with(actual_str) {
55 &expected_str[actual_str.len()..]
56 } else {
57 return None;
58 };
59
60 if remaining.is_empty() {
61 return None;
62 }
63
64 Some(vec![ToolCallItem {
66 tool_index,
67 name: None, parameters: remaining.to_string(),
69 }])
70}
71
/// Checks whether `buffer` ends with a proper, non-empty prefix of `token`.
///
/// Returns the byte length of the *shortest* prefix of `token` that `buffer`
/// ends with, or `None` if there is no such prefix. The full token is never
/// considered a partial match, and an empty `buffer` or `token` always yields
/// `None`.
///
/// Only prefixes that end on a UTF-8 character boundary of `token` are tried,
/// so tokens containing multi-byte characters are handled without panicking.
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
    if buffer.is_empty() || token.is_empty() {
        return None;
    }

    // `char_indices().skip(1)` yields exactly the char-boundary offsets in
    // `1..token.len()`: every candidate prefix length except 0 and the full token.
    token
        .char_indices()
        .skip(1)
        .map(|(boundary, _)| boundary)
        .find(|&boundary| buffer.ends_with(&token[..boundary]))
}
81
82pub fn reset_current_tool_state(
86 buffer: &mut String,
87 current_tool_name_sent: &mut bool,
88 streamed_args_for_tool: &mut Vec<String>,
89 prev_tool_call_arr: &[Value],
90) {
91 buffer.clear();
92 *current_tool_name_sent = false;
93
94 if streamed_args_for_tool.len() > prev_tool_call_arr.len() {
97 streamed_args_for_tool.pop();
98 }
99}
100
101pub fn reset_parser_state(
104 buffer: &mut String,
105 prev_tool_call_arr: &mut Vec<Value>,
106 current_tool_id: &mut i32,
107 current_tool_name_sent: &mut bool,
108 streamed_args_for_tool: &mut Vec<String>,
109) {
110 buffer.clear();
111 prev_tool_call_arr.clear();
112 *current_tool_id = -1;
113 *current_tool_name_sent = false;
114 streamed_args_for_tool.clear();
115}
116
117pub fn ensure_capacity(
119 current_tool_id: i32,
120 prev_tool_call_arr: &mut Vec<Value>,
121 streamed_args_for_tool: &mut Vec<String>,
122) {
123 if current_tool_id < 0 {
124 return;
125 }
126 let needed = (current_tool_id + 1) as usize;
127
128 if prev_tool_call_arr.len() < needed {
129 prev_tool_call_arr.resize_with(needed, || Value::Null);
130 }
131 if streamed_args_for_tool.len() < needed {
132 streamed_args_for_tool.resize_with(needed, String::new);
133 }
134}
135
136pub fn is_complete_json(input: &str) -> bool {
138 serde_json::from_str::<Value>(input).is_ok()
139}
140
141pub fn normalize_arguments_field(mut obj: Value) -> Value {
151 if obj.get("arguments").is_none() {
152 if let Some(params) = obj.get("parameters").cloned() {
153 if let Value::Object(ref mut map) = obj {
154 map.insert("arguments".to_string(), params);
155 }
156 }
157 }
158 obj
159}
160
161pub fn normalize_name_field(mut obj: Value) -> Value {
171 if obj.get("name").is_none() {
172 if let Some(tool_name) = obj.get("tool_name").cloned() {
173 if let Value::Object(ref mut map) = obj {
174 map.insert("name".to_string(), tool_name);
175 }
176 }
177 }
178 obj
179}
180
181pub fn normalize_tool_call_fields(obj: Value) -> Value {
187 let obj = normalize_name_field(obj);
188 normalize_arguments_field(obj)
189}
190
191#[allow(clippy::too_many_arguments)]
217pub(crate) fn handle_json_tool_streaming(
218 current_text: &str,
219 start_idx: usize,
220 partial_json: &mut crate::partial_json::PartialJson,
221 tool_indices: &HashMap<String, usize>,
222 buffer: &mut String,
223 current_tool_id: &mut i32,
224 current_tool_name_sent: &mut bool,
225 streamed_args_for_tool: &mut Vec<String>,
226 prev_tool_call_arr: &mut Vec<Value>,
227) -> ParserResult<StreamingParseResult> {
228 if start_idx >= current_text.len() {
230 return Ok(StreamingParseResult::default());
231 }
232
233 let json_str = ¤t_text[start_idx..];
235
236 let allow_partial_strings = *current_tool_name_sent;
239
240 let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
242 Ok(result) => result,
243 Err(_) => {
244 return Ok(StreamingParseResult::default());
245 }
246 };
247
248 let safe_end_idx = if json_str.is_char_boundary(end_idx) {
251 end_idx
252 } else {
253 (0..end_idx)
255 .rev()
256 .find(|&i| json_str.is_char_boundary(i))
257 .unwrap_or(0)
258 };
259 let is_complete = serde_json::from_str::<Value>(&json_str[..safe_end_idx]).is_ok();
260
261 let current_tool_call = normalize_tool_call_fields(obj);
264
265 if let Some(name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
267 if !tool_indices.contains_key(name) {
268 tracing::debug!("Invalid tool name '{}' - skipping", name);
270 reset_current_tool_state(
271 buffer,
272 current_tool_name_sent,
273 streamed_args_for_tool,
274 prev_tool_call_arr,
275 );
276 return Ok(StreamingParseResult::default());
277 }
278 }
279
280 let mut result = StreamingParseResult::default();
281
282 if !*current_tool_name_sent {
284 if let Some(function_name) = current_tool_call.get("name").and_then(|v| v.as_str()) {
285 if tool_indices.contains_key(function_name) {
286 if *current_tool_id == -1 {
288 *current_tool_id = 0;
289 streamed_args_for_tool.push(String::new());
290 } else if *current_tool_id as usize >= streamed_args_for_tool.len() {
291 ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
293 }
294
295 *current_tool_name_sent = true;
297 result.calls.push(ToolCallItem {
298 tool_index: *current_tool_id as usize,
299 name: Some(function_name.to_string()),
300 parameters: String::new(),
301 });
302 }
303 }
304 }
305 else if let Some(cur_arguments) = current_tool_call.get("arguments") {
307 let tool_id = *current_tool_id as usize;
308 let sent = streamed_args_for_tool
309 .get(tool_id)
310 .map(|s| s.len())
311 .unwrap_or(0);
312 let cur_args_json = serde_json::to_string(cur_arguments)
313 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
314
315 let prev_arguments = if tool_id < prev_tool_call_arr.len() {
317 prev_tool_call_arr[tool_id].get("arguments")
318 } else {
319 None
320 };
321
322 let mut argument_diff = None;
324
325 if is_complete {
326 argument_diff = if sent < cur_args_json.len() {
329 Some(cur_args_json[sent..].to_string())
330 } else {
331 Some(String::new())
332 };
333 } else if let Some(prev_args) = prev_arguments {
334 let prev_args_json = serde_json::to_string(prev_args)
335 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
336
337 if cur_args_json != prev_args_json {
338 let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
339 argument_diff = if sent < prefix.len() {
340 Some(prefix[sent..].to_string())
341 } else {
342 Some(String::new())
343 };
344 }
345 }
346
347 if let Some(diff) = argument_diff {
349 if !diff.is_empty() {
350 if tool_id < streamed_args_for_tool.len() {
351 streamed_args_for_tool[tool_id].push_str(&diff);
352 }
353 result.calls.push(ToolCallItem {
354 tool_index: tool_id,
355 name: None,
356 parameters: diff,
357 });
358 }
359 }
360
361 if *current_tool_id >= 0 {
363 ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
364
365 if tool_id < prev_tool_call_arr.len() {
366 prev_tool_call_arr[tool_id] = current_tool_call;
367 }
368 }
369
370 if is_complete {
372 *buffer = current_text[start_idx + end_idx..].to_string();
373 *current_tool_name_sent = false;
374 *current_tool_id += 1;
375 }
376 }
377
378 Ok(result)
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Trailing prefixes of the token are detected; a complete token, an empty
    /// buffer and unrelated text are not.
    #[test]
    fn test_ends_with_partial_token() {
        let token = "<|python_tag|>";

        for &text in &["hello <|py", "hello <|python_tag"] {
            assert!(ends_with_partial_token(text, token).is_some());
        }
        for &text in &["hello <|python_tag|>", "", "hello world"] {
            assert!(ends_with_partial_token(text, token).is_none());
        }
    }

    /// With more streamed-args slots than previous calls, the extra slot is dropped.
    #[test]
    fn test_reset_current_tool_state() {
        let completed = vec![serde_json::json!({"name": "tool0"})];
        let mut args = vec![String::from("tool0_args"), String::from("tool1_partial")];
        let mut name_sent = true;
        let mut buf = String::from("partial json");

        reset_current_tool_state(&mut buf, &mut name_sent, &mut args, &completed);

        assert!(buf.is_empty());
        assert!(!name_sent);
        assert_eq!(args, vec![String::from("tool0_args")]);
    }

    /// With equally long vectors, no streamed-args slot is removed.
    #[test]
    fn test_reset_current_tool_state_no_pop_when_synced() {
        let completed = vec![serde_json::json!({"name": "tool0"})];
        let mut args = vec![String::from("tool0_args")];
        let mut name_sent = true;
        let mut buf = String::from("partial json");

        reset_current_tool_state(&mut buf, &mut name_sent, &mut args, &completed);

        assert!(buf.is_empty());
        assert!(!name_sent);
        assert_eq!(args.len(), 1);
    }

    /// A full reset empties every collection and restores the sentinel values.
    #[test]
    fn test_reset_parser_state() {
        let mut buf = String::from("some buffer");
        let mut completed = vec![serde_json::json!({"name": "tool0"})];
        let mut tool_id = 5;
        let mut name_sent = true;
        let mut args = vec![String::from("args")];

        reset_parser_state(
            &mut buf,
            &mut completed,
            &mut tool_id,
            &mut name_sent,
            &mut args,
        );

        assert!(buf.is_empty());
        assert!(completed.is_empty());
        assert_eq!(tool_id, -1);
        assert!(!name_sent);
        assert!(args.is_empty());
    }

    /// Asking for index 2 pads both vectors to three default entries.
    #[test]
    fn test_ensure_capacity() {
        let mut completed: Vec<Value> = Vec::new();
        let mut args: Vec<String> = Vec::new();

        ensure_capacity(2, &mut completed, &mut args);

        assert_eq!(completed.len(), 3);
        assert_eq!(args.len(), 3);
        assert_eq!(completed[0], Value::Null);
        assert_eq!(args[0], "");
    }

    /// A negative tool id leaves both vectors untouched.
    #[test]
    fn test_ensure_capacity_negative_id() {
        let mut completed: Vec<Value> = Vec::new();
        let mut args: Vec<String> = Vec::new();

        ensure_capacity(-1, &mut completed, &mut args);

        assert!(completed.is_empty());
        assert!(args.is_empty());
    }

    /// Whole JSON values of any kind are complete; truncated input is not.
    #[test]
    fn test_is_complete_json() {
        for &complete in &[r#"{"name": "test"}"#, "[1, 2, 3]", "42", "true"] {
            assert!(is_complete_json(complete));
        }
        for &truncated in &[r#"{"name": "#, "[1, 2,"] {
            assert!(!is_complete_json(truncated));
        }
    }

    #[test]
    fn test_normalize_arguments_field() {
        // "parameters" is mirrored into "arguments" when the latter is absent.
        let from_parameters = normalize_arguments_field(serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        }));
        assert_eq!(
            from_parameters.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );

        // An object that already has "arguments" is left as is.
        let with_arguments = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(
            normalize_arguments_field(with_arguments.clone()),
            with_arguments
        );

        // With neither field present nothing changes.
        let name_only = serde_json::json!({"name": "test"});
        assert_eq!(normalize_arguments_field(name_only.clone()), name_only);
    }

    #[test]
    fn test_normalize_name_field() {
        // "tool_name" is mirrored into "name" when the latter is absent.
        let from_tool_name = normalize_name_field(serde_json::json!({
            "tool_name": "search",
            "parameters": {"query": "test"}
        }));
        assert_eq!(from_tool_name.get("name").unwrap(), "search");

        // An object that already has "name" is left as is.
        let with_name = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_name_field(with_name.clone()), with_name);

        // When both are present, "name" wins.
        let both = normalize_name_field(serde_json::json!({
            "tool_name": "cohere_name",
            "name": "standard_name",
            "parameters": {}
        }));
        assert_eq!(both.get("name").unwrap(), "standard_name");

        // With neither field present no "name" is invented.
        let nameless = normalize_name_field(serde_json::json!({"parameters": {}}));
        assert!(nameless.get("name").is_none());
    }

    #[test]
    fn test_normalize_tool_call_fields() {
        // Both alias fields are normalized in one pass.
        let aliased = normalize_tool_call_fields(serde_json::json!({
            "tool_name": "search",
            "parameters": {"query": "rust programming"}
        }));
        assert_eq!(aliased.get("name").unwrap(), "search");
        assert_eq!(
            aliased.get("arguments").unwrap(),
            &serde_json::json!({"query": "rust programming"})
        );

        // The standard shape passes through unchanged.
        let standard = serde_json::json!({
            "name": "test",
            "arguments": {"key": "value"}
        });
        assert_eq!(normalize_tool_call_fields(standard.clone()), standard);

        // A standard name combined with aliased arguments is normalized too.
        let mixed = normalize_tool_call_fields(serde_json::json!({
            "name": "test",
            "parameters": {"key": "value"}
        }));
        assert_eq!(mixed.get("name").unwrap(), "test");
        assert_eq!(
            mixed.get("arguments").unwrap(),
            &serde_json::json!({"key": "value"})
        );
    }
}