sh_layer3/tool_executor/
executor.rs1use crate::builtin_tools::{BuiltinTool, BuiltinToolRegistry};
6use crate::tool_executor::{ToolExecutor, ToolValidator};
7use crate::types::{Layer3Error, Layer3Result, ToolMeta, ToolRequest, ToolResponse};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use std::time::Instant;
13
14pub struct DefaultToolExecutor {
16 builtin: BuiltinToolRegistry,
18 history: Arc<RwLock<VecDeque<ExecutionRecord>>>,
20 max_history: usize,
22}
23
24#[derive(Debug, Clone)]
26pub struct ExecutionRecord {
27 pub request: ToolRequest,
28 pub response: ToolResponse,
29 pub timestamp: chrono::DateTime<chrono::Utc>,
30 pub duration_ms: u64,
31}
32
33impl DefaultToolExecutor {
34 pub fn new() -> Self {
35 Self {
36 builtin: BuiltinToolRegistry::with_defaults(),
37 history: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
38 max_history: 1000,
39 }
40 }
41
42 pub fn register_tool(&mut self, tool: Box<dyn BuiltinTool>) {
44 self.builtin.register(tool);
45 }
46
47 pub fn history(&self) -> Vec<ExecutionRecord> {
49 self.history.read().iter().cloned().collect()
50 }
51
52 pub fn clear_history(&self) {
54 self.history.write().clear();
55 }
56
57 fn record(&self, request: &ToolRequest, response: &ToolResponse, duration_ms: u64) {
59 let mut history = self.history.write();
60 if history.len() >= self.max_history {
61 history.pop_front();
62 }
63 history.push_back(ExecutionRecord {
64 request: request.clone(),
65 response: response.clone(),
66 timestamp: chrono::Utc::now(),
67 duration_ms,
68 });
69 }
70}
71
72impl Default for DefaultToolExecutor {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78#[async_trait]
79impl ToolExecutor for DefaultToolExecutor {
80 async fn execute(&self, request: ToolRequest) -> Layer3Result<ToolResponse> {
81 let start = Instant::now();
82
83 let tool = self
85 .builtin
86 .get(&request.name)
87 .ok_or_else(|| Layer3Error::ToolNotFound(request.name.clone()))?;
88
89 let result = tool
91 .execute(request.arguments.clone())
92 .await
93 .map_err(|e| Layer3Error::ToolExecutionFailed(e.to_string()))?;
94
95 let duration_ms = start.elapsed().as_millis() as u64;
96
97 let response = ToolResponse {
98 call_id: request.call_id.clone(),
99 name: request.name.clone(),
100 content: result,
101 is_error: false,
102 duration_ms,
103 };
104
105 self.record(&request, &response, duration_ms);
107
108 Ok(response)
109 }
110
111 async fn execute_batch(&self, requests: Vec<ToolRequest>) -> Layer3Result<Vec<ToolResponse>> {
112 let mut results = Vec::with_capacity(requests.len());
113 for request in requests {
114 results.push(self.execute(request).await?);
115 }
116 Ok(results)
117 }
118
119 fn is_available(&self, name: &str) -> bool {
120 self.builtin.get(name).is_some()
121 }
122
123 fn get_meta(&self, name: &str) -> Option<ToolMeta> {
124 self.builtin.get(name).map(|t| t.meta())
126 }
127
128 fn list_tools(&self) -> Vec<ToolMeta> {
129 self.builtin.list_meta()
130 }
131}
132
133pub struct JsonSchemaValidator {
135 schemas: HashMap<String, serde_json::Value>,
136}
137
138impl JsonSchemaValidator {
139 pub fn new() -> Self {
140 Self {
141 schemas: HashMap::new(),
142 }
143 }
144
145 pub fn register_schema(&mut self, tool_name: String, schema: serde_json::Value) {
146 self.schemas.insert(tool_name, schema);
147 }
148}
149
150impl Default for JsonSchemaValidator {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156impl ToolValidator for JsonSchemaValidator {
157 fn validate(&self, request: &ToolRequest) -> bool {
158 if let Some(schema) = self.schemas.get(&request.name) {
160 if let Some(required) = schema.get("required") {
161 if let Some(required_arr) = required.as_array() {
162 for field in required_arr {
163 if let Some(field_name) = field.as_str() {
164 if request.arguments.get(field_name).is_none() {
165 return false;
166 }
167 }
168 }
169 }
170 }
171 }
172 true
173 }
174
175 fn validate_with_reason(&self, request: &ToolRequest) -> Result<(), String> {
176 if self.validate(request) {
177 Ok(())
178 } else {
179 Err(format!("Validation failed for tool: {}", request.name))
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for the tool `name` with the given arguments.
    fn request(name: &str, arguments: serde_json::Value) -> ToolRequest {
        ToolRequest {
            call_id: "1".to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    /// A freshly created executor exposes the default built-in tools.
    #[tokio::test]
    async fn test_executor_creation() {
        let executor = DefaultToolExecutor::new();
        let tools = executor.list_tools();
        assert!(!tools.is_empty(), "Expected tools to be registered");
        assert!(
            tools.len() >= 5,
            "Expected at least 5 basic tools, got {}",
            tools.len()
        );
    }

    /// Executing an unregistered tool fails and leaves the history empty.
    #[tokio::test]
    async fn test_unknown_tool_is_rejected_and_not_recorded() {
        let executor = DefaultToolExecutor::new();
        let unknown = "__nonexistent_tool__";
        assert!(!executor.is_available(unknown));

        let outcome = executor
            .execute(request(unknown, serde_json::json!({})))
            .await;
        assert!(outcome.is_err());
        assert!(executor.history().is_empty());
    }

    /// Without a registered schema every request is accepted.
    #[test]
    fn test_validator() {
        let validator = JsonSchemaValidator::new();
        let request = request("test", serde_json::json!({"path": "test"}));
        assert!(validator.validate(&request));
    }

    /// With a schema, requests missing a required field are rejected.
    #[test]
    fn test_validator_required_fields() {
        let mut validator = JsonSchemaValidator::new();
        validator.register_schema(
            "test".to_string(),
            serde_json::json!({"required": ["path"]}),
        );

        let complete = request("test", serde_json::json!({"path": "test"}));
        assert!(validator.validate(&complete));
        assert!(validator.validate_with_reason(&complete).is_ok());

        let incomplete = request("test", serde_json::json!({"other": 1}));
        assert!(!validator.validate(&incomplete));
        assert!(validator.validate_with_reason(&incomplete).is_err());

        // A schema registered for one tool does not affect other tools.
        let other_tool = request("other", serde_json::json!({}));
        assert!(validator.validate(&other_tool));
    }
}