rust_agent/tools/
utils.rs1use crate::tools::Tool;
2use std::boxed::Box;
3use serde_json::Value;
4use std::collections::HashMap;
5use crate::agents::{AgentOutput, AgentAction, AgentFinish};
6
7pub fn find_matching_tool_index(tools: &[Box<dyn Tool + Send + Sync>], requested_tool: &str) -> Option<String> {
9 if let Some(tool) = tools.iter().find(|t| {
11 t.name() == requested_tool
12 }) {
13 return Some(tool.name().to_string());
14 }
15
16 let requested_lower = requested_tool.to_lowercase();
18 for tool in tools {
19 let tool_name_lower = tool.name().to_lowercase();
20 if tool_name_lower.contains(&requested_lower) || requested_lower.contains(&tool_name_lower) {
22 return Some(tool.name().to_string());
23 }
24 }
25
26 if requested_lower.contains("weather") || requested_lower.contains("天气") {
28 if let Some(tool) = tools.iter().find(|t| t.name().to_lowercase().contains("weather") || t.name().to_lowercase().contains("天气")) {
29 return Some(tool.name().to_string());
30 }
31 }
32
33 let has_calc_keywords = requested_lower.contains("calculate") || requested_lower.contains("计算") ||
36 requested_lower.contains("plus") || requested_lower.contains("minus") ||
37 requested_lower.contains("times") || requested_lower.contains("divided");
38
39 let has_math_operators = requested_lower.contains("+") || requested_lower.contains("-") ||
40 requested_lower.contains("*") || requested_lower.contains("/") ||
41 requested_lower.contains("plus") || requested_lower.contains("minus") ||
42 requested_lower.contains("times") || requested_lower.contains("divided");
43
44 if has_calc_keywords && has_math_operators {
45 if let Some(tool) = tools.iter().find(|t| t.name().to_lowercase().contains("calculate") || t.name().to_lowercase().contains("计算")) {
46 return Some(tool.name().to_string());
47 }
48 }
49
50 None
51}
52
53pub fn parse_model_output(content: &str) -> Result<AgentOutput, anyhow::Error> {
55 if let Ok(value) = serde_json::from_str::<Value>(content) {
57 if let Some(call_tool) = value.get("call_tool") {
59 if let Some(tool_name) = call_tool.get("name") {
61 let tool_name = tool_name.as_str().unwrap_or("unknown").to_string();
62
63 let parameters = call_tool.get("parameters").cloned().unwrap_or(Value::Object(serde_json::Map::new()));
65
66 let tool_input = parameters.to_string();
68
69 return Ok(AgentOutput::Action(AgentAction {
70 tool: tool_name,
71 tool_input,
72 log: "Call tool".to_string(),
73 thought: Some("Call tool based on model output".to_string()),
74 }));
75 }
76 }
77
78 if let Some(content_value) = value.get("content") {
80 let content_text = content_value.as_str().unwrap_or("").to_string();
81 let mut return_values = HashMap::new();
82 return_values.insert("answer".to_string(), content_text);
83
84 return Ok(AgentOutput::Finish(AgentFinish {
85 return_values,
86 }));
87 }
88 }
89
90 Err(anyhow::anyhow!("Failed to parse model output"))
92}