Skip to main content

sh_layer3/tool_executor/
executor.rs

1//! # Default Tool Executor Implementation
2//!
3//! 工具执行器的默认实现。
4
5use crate::builtin_tools::{BuiltinTool, BuiltinToolRegistry};
6use crate::tool_executor::{ToolExecutor, ToolValidator};
7use crate::types::{Layer3Error, Layer3Result, ToolMeta, ToolRequest, ToolResponse};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use std::time::Instant;
13
14/// 默认工具执行器
15pub struct DefaultToolExecutor {
16    /// 内置工具注册表
17    builtin: BuiltinToolRegistry,
18    /// 执行历史(用于调试)
19    history: Arc<RwLock<VecDeque<ExecutionRecord>>>,
20    /// 最大历史记录数
21    max_history: usize,
22}
23
24/// 执行记录
25#[derive(Debug, Clone)]
26pub struct ExecutionRecord {
27    pub request: ToolRequest,
28    pub response: ToolResponse,
29    pub timestamp: chrono::DateTime<chrono::Utc>,
30    pub duration_ms: u64,
31}
32
33impl DefaultToolExecutor {
34    pub fn new() -> Self {
35        Self {
36            builtin: BuiltinToolRegistry::with_defaults(),
37            history: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
38            max_history: 1000,
39        }
40    }
41
42    /// 注册内置工具
43    pub fn register_tool(&mut self, tool: Box<dyn BuiltinTool>) {
44        self.builtin.register(tool);
45    }
46
47    /// 获取执行历史
48    pub fn history(&self) -> Vec<ExecutionRecord> {
49        self.history.read().iter().cloned().collect()
50    }
51
52    /// 清空历史
53    pub fn clear_history(&self) {
54        self.history.write().clear();
55    }
56
57    /// 记录执行
58    fn record(&self, request: &ToolRequest, response: &ToolResponse, duration_ms: u64) {
59        let mut history = self.history.write();
60        if history.len() >= self.max_history {
61            history.pop_front();
62        }
63        history.push_back(ExecutionRecord {
64            request: request.clone(),
65            response: response.clone(),
66            timestamp: chrono::Utc::now(),
67            duration_ms,
68        });
69    }
70}
71
72impl Default for DefaultToolExecutor {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78#[async_trait]
79impl ToolExecutor for DefaultToolExecutor {
80    async fn execute(&self, request: ToolRequest) -> Layer3Result<ToolResponse> {
81        let start = Instant::now();
82
83        // 查找工具
84        let tool = self
85            .builtin
86            .get(&request.name)
87            .ok_or_else(|| Layer3Error::ToolNotFound(request.name.clone()))?;
88
89        // 执行工具
90        let result = tool
91            .execute(request.arguments.clone())
92            .await
93            .map_err(|e| Layer3Error::ToolExecutionFailed(e.to_string()))?;
94
95        let duration_ms = start.elapsed().as_millis() as u64;
96
97        let response = ToolResponse {
98            call_id: request.call_id.clone(),
99            name: request.name.clone(),
100            content: result,
101            is_error: false,
102            duration_ms,
103        };
104
105        // 记录
106        self.record(&request, &response, duration_ms);
107
108        Ok(response)
109    }
110
111    async fn execute_batch(&self, requests: Vec<ToolRequest>) -> Layer3Result<Vec<ToolResponse>> {
112        let mut results = Vec::with_capacity(requests.len());
113        for request in requests {
114            results.push(self.execute(request).await?);
115        }
116        Ok(results)
117    }
118
119    fn is_available(&self, name: &str) -> bool {
120        self.builtin.get(name).is_some()
121    }
122
123    fn get_meta(&self, name: &str) -> Option<ToolMeta> {
124        // 需要在 BuiltinToolRegistry 中添加 get_meta 方法
125        self.builtin.get(name).map(|t| t.meta())
126    }
127
128    fn list_tools(&self) -> Vec<ToolMeta> {
129        self.builtin.list_meta()
130    }
131}
132
133/// 参数验证器
134pub struct JsonSchemaValidator {
135    schemas: HashMap<String, serde_json::Value>,
136}
137
138impl JsonSchemaValidator {
139    pub fn new() -> Self {
140        Self {
141            schemas: HashMap::new(),
142        }
143    }
144
145    pub fn register_schema(&mut self, tool_name: String, schema: serde_json::Value) {
146        self.schemas.insert(tool_name, schema);
147    }
148}
149
150impl Default for JsonSchemaValidator {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl ToolValidator for JsonSchemaValidator {
157    fn validate(&self, request: &ToolRequest) -> bool {
158        // 简化验证:检查必需字段是否存在
159        if let Some(schema) = self.schemas.get(&request.name) {
160            if let Some(required) = schema.get("required") {
161                if let Some(required_arr) = required.as_array() {
162                    for field in required_arr {
163                        if let Some(field_name) = field.as_str() {
164                            if request.arguments.get(field_name).is_none() {
165                                return false;
166                            }
167                        }
168                    }
169                }
170            }
171        }
172        true
173    }
174
175    fn validate_with_reason(&self, request: &ToolRequest) -> Result<(), String> {
176        if self.validate(request) {
177            Ok(())
178        } else {
179            Err(format!("Validation failed for tool: {}", request.name))
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    fn request(name: &str, arguments: serde_json::Value) -> ToolRequest {
        ToolRequest {
            call_id: "1".to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    #[tokio::test]
    async fn test_executor_creation() {
        let executor = DefaultToolExecutor::new();
        // The default constructor registers the built-in tools.
        let tools = executor.list_tools();
        assert!(!tools.is_empty(), "Expected tools to be registered");
        // At least the basic tools: read_file, write_file, bash, grep, glob.
        assert!(
            tools.len() >= 5,
            "Expected at least 5 basic tools, got {}",
            tools.len()
        );
    }

    #[test]
    fn test_unknown_tool_is_unavailable() {
        let executor = DefaultToolExecutor::new();
        assert!(!executor.is_available("definitely_not_a_registered_tool"));
        assert!(executor.get_meta("definitely_not_a_registered_tool").is_none());
    }

    #[tokio::test]
    async fn test_execute_unknown_tool_fails() {
        let executor = DefaultToolExecutor::new();
        let result = executor
            .execute(request("definitely_not_a_registered_tool", serde_json::json!({})))
            .await;
        assert!(matches!(result, Err(Layer3Error::ToolNotFound(_))));
    }

    #[test]
    fn test_validator() {
        let validator = JsonSchemaValidator::new();
        let request = request("test", serde_json::json!({"path": "test"}));
        assert!(validator.validate(&request));
    }

    #[test]
    fn test_validator_enforces_required_fields() {
        let mut validator = JsonSchemaValidator::new();
        validator.register_schema(
            "test".to_string(),
            serde_json::json!({"type": "object", "required": ["path"]}),
        );

        let ok = request("test", serde_json::json!({"path": "a"}));
        assert!(validator.validate(&ok));
        assert!(validator.validate_with_reason(&ok).is_ok());

        let missing = request("test", serde_json::json!({}));
        assert!(!validator.validate(&missing));
        let reason = validator.validate_with_reason(&missing).unwrap_err();
        assert!(reason.starts_with("Validation failed for tool: test"));

        // Schemas registered for one tool do not affect other tools.
        assert!(validator.validate(&request("other", serde_json::json!({}))));
    }
}
212}