1use std::time::Instant;
26use tokio::time::timeout;
27use tracing::{debug, error, info, warn};
28
29use crate::tool::{Capability, ToolRegistry};
30use crate::tool_error::ToolError;
31use crate::tool_result::ToolResult;
32
/// Runs tools looked up in a [`ToolRegistry`], applying a capability (sandbox) check,
/// an availability check, JSON-schema and custom argument validation, and a per-tool
/// timeout around each call.
pub struct ToolExecutor {
    /// Tools this executor can dispatch to, looked up by name.
    registry: ToolRegistry,
    /// When true, `execute` emits an "Audit entry created" debug event after each
    /// successful call.
    audit_enabled: bool,
    /// Maximum number of calls `execute_parallel` keeps in flight at once;
    /// `0` means no limit.
    max_parallel: usize,
    /// Capabilities a tool is allowed to require; `execute` rejects any tool that
    /// requires a capability not in this list.
    allowed_capabilities: Vec<Capability>,
}
65
66impl ToolExecutor {
67 pub fn new(registry: ToolRegistry) -> Self {
69 Self {
70 registry,
71 audit_enabled: true,
72 max_parallel: 0, allowed_capabilities: vec![
74 Capability::PureComputation,
75 Capability::Network,
76 Capability::FileSystem,
77 Capability::Subprocess,
78 Capability::Environment,
79 Capability::Cryptography,
80 ],
81 }
82 }
83
84 pub fn without_audit(registry: ToolRegistry) -> Self {
86 Self {
87 registry,
88 audit_enabled: false,
89 max_parallel: 0,
90 allowed_capabilities: vec![
91 Capability::PureComputation,
92 Capability::Network,
93 Capability::FileSystem,
94 Capability::Subprocess,
95 Capability::Environment,
96 Capability::Cryptography,
97 ],
98 }
99 }
100
101 pub fn with_max_parallel(mut self, max: usize) -> Self {
103 self.max_parallel = max;
104 self
105 }
106
107 pub fn with_allowed_capabilities(mut self, caps: Vec<Capability>) -> Self {
109 self.allowed_capabilities = caps;
110 self
111 }
112
113 pub fn register_wasm_tool(
115 &mut self,
116 definition: crate::tool::ToolDefinition,
117 module_bytes: Vec<u8>,
118 capabilities: Vec<Capability>,
119 ) {
120 use crate::wasm_tool::WasmTool;
121 use std::sync::Arc;
122 let tool = WasmTool::new(definition, module_bytes, capabilities);
123 self.registry.register(Arc::new(tool));
124 }
125
126 pub fn register_wasm_tool_from_file(
128 &mut self,
129 definition: crate::tool::ToolDefinition,
130 path: impl AsRef<std::path::Path>,
131 capabilities: Vec<Capability>,
132 ) -> Result<(), std::io::Error> {
133 let bytes = std::fs::read(path)?;
134 self.register_wasm_tool(definition, bytes, capabilities);
135 Ok(())
136 }
137
138 pub async fn execute(
156 &self,
157 tool_name: &str,
158 args: serde_json::Value,
159 ) -> Result<ToolResult, ToolError> {
160 let tool = self.registry.get(tool_name).ok_or_else(|| {
162 warn!(tool = tool_name, "Tool not found");
163 ToolError::not_found(tool_name)
164 })?;
165
166 for cap in tool.capabilities() {
168 if !self.allowed_capabilities.contains(&cap) {
169 warn!(
170 tool = tool_name,
171 capability = ?cap,
172 "Tool requires missing capability"
173 );
174 return Err(ToolError::unavailable(
175 tool_name,
176 format!("Sandbox violation: tool requires {:?}", cap),
177 ));
178 }
179 }
180
181 if !tool.is_available() {
183 warn!(tool = tool_name, "Tool is unavailable");
184 return Err(ToolError::unavailable(
185 tool_name,
186 "Tool is currently disabled",
187 ));
188 }
189
190 debug!(tool = tool_name, "Validating arguments against schema");
192 let schema_str = tool.definition().parameters;
193 if !schema_str.is_empty() && schema_str != "{}" {
194 let schema_json: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
195 ToolError::execution_failed(tool_name, format!("Invalid tool schema: {}", e))
196 })?;
197
198 let compiled = jsonschema::JSONSchema::compile(&schema_json).map_err(|e| {
199 ToolError::execution_failed(
200 tool_name,
201 format!("Failed to compile tool schema: {}", e),
202 )
203 })?;
204
205 if !compiled.is_valid(&args) {
206 warn!(tool = tool_name, "Schema validation failed");
207 return Err(ToolError::invalid_args(
208 tool_name,
209 "Arguments do not match tool schema",
210 ));
211 }
212 }
213
214 debug!(tool = tool_name, "Running custom validation");
216 tool.validate(&args)?;
217
218 let tool_timeout = tool.timeout();
220 let start = Instant::now();
221
222 debug!(
223 tool = tool_name,
224 timeout_ms = tool_timeout.as_millis(),
225 "Executing tool"
226 );
227
228 let output = timeout(tool_timeout, tool.execute(args.clone()))
229 .await
230 .map_err(|_| {
231 error!(
232 tool = tool_name,
233 timeout_ms = tool_timeout.as_millis(),
234 "Tool execution timed out"
235 );
236 ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
237 })??;
238
239 let elapsed = start.elapsed();
240
241 let result = ToolResult::new(tool_name, &args, output, elapsed);
243
244 info!(
246 tool = tool_name,
247 execution_ms = elapsed.as_millis(),
248 hash = %result.hash,
249 "Tool executed successfully"
250 );
251
252 if self.audit_enabled {
255 debug!(
256 tool = tool_name,
257 result_hash = %result.hash,
258 "Audit entry created"
259 );
260 }
261
262 Ok(result)
263 }
264
265 pub async fn execute_parallel(
281 &self,
282 calls: Vec<(String, serde_json::Value)>,
283 ) -> Vec<Result<ToolResult, ToolError>> {
284 debug!(count = calls.len(), "Executing tools in parallel");
285
286 if self.max_parallel > 0 {
287 use futures::stream::{self, StreamExt};
288
289 stream::iter(calls)
290 .map(|(name, args)| async move { self.execute(&name, args).await })
291 .buffered(self.max_parallel)
292 .collect()
293 .await
294 } else {
295 let futures: Vec<_> = calls
296 .into_iter()
297 .map(|(name, args)| {
298 async move { self.execute(&name, args).await }
300 })
301 .collect();
302
303 futures::future::join_all(futures).await
304 }
305 }
306
307 pub fn registry(&self) -> &ToolRegistry {
309 &self.registry
310 }
311
312 pub fn registry_mut(&mut self) -> &mut ToolRegistry {
314 &mut self.registry
315 }
316
317 pub fn has_tool(&self, name: &str) -> bool {
319 self.registry.contains(name)
320 }
321
322 pub fn tool_names(&self) -> Vec<&str> {
324 self.registry.names()
325 }
326}
327
328impl std::fmt::Debug for ToolExecutor {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.debug_struct("ToolExecutor")
331 .field("tools", &self.registry.names())
332 .field("audit_enabled", &self.audit_enabled)
333 .field("max_parallel", &self.max_parallel)
334 .finish()
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    /// Returns its arguments wrapped as `{"echo": args}`.
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    /// Always returns an `ExecutionFailed` error.
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    /// Sleeps far longer than its own 50 ms timeout.
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(50)
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    /// A registry containing only `EchoTool`.
    fn echo_registry() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        registry
    }

    #[tokio::test]
    async fn test_execute_success() {
        let executor = ToolExecutor::new(echo_registry());
        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert_eq!(result.output["echo"]["message"], "hello");
        assert!(!result.hash.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_execute_without_audit_succeeds() {
        let executor = ToolExecutor::without_audit(echo_registry());
        let result = executor
            .execute("echo", serde_json::json!({"message": "quiet"}))
            .await
            .unwrap();

        assert_eq!(result.output["echo"]["message"], "quiet");
    }

    #[tokio::test]
    async fn test_execute_not_found() {
        let executor = ToolExecutor::new(ToolRegistry::new());

        let result = executor.execute("nonexistent", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_execute_rejects_args_violating_schema() {
        // EchoTool's schema requires an object; a bare string must be rejected
        // before the tool runs.
        let executor = ToolExecutor::new(echo_registry());

        let result = executor
            .execute("echo", serde_json::json!("not an object"))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("fail", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("slow", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_execute_parallel() {
        let executor = ToolExecutor::new(echo_registry());

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_execute_parallel_bounded_preserves_order() {
        // Exercises the `max_parallel > 0` (buffered) branch.
        let executor = ToolExecutor::new(echo_registry()).with_max_parallel(2);

        let calls: Vec<_> = (1..=5u64)
            .map(|n| ("echo".to_string(), serde_json::json!({ "n": n })))
            .collect();

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            let result = result.as_ref().expect("echo call should succeed");
            assert_eq!(result.output["echo"]["n"], (i + 1) as u64);
        }
    }

    #[tokio::test]
    async fn test_execute_parallel_isolates_failures() {
        let executor = ToolExecutor::new(echo_registry());

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("missing".to_string(), serde_json::json!({})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results[0].is_ok());
        assert!(matches!(results[1], Err(ToolError::NotFound { .. })));
        assert!(results[2].is_ok());
    }

    #[tokio::test]
    async fn test_has_tool() {
        let executor = ToolExecutor::new(echo_registry());

        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
        assert_eq!(executor.tool_names(), vec!["echo"]);
    }
}