// vex_llm/tool_executor.rs

//! Tool Executor with Merkle audit integration
//!
//! This module provides `ToolExecutor` which wraps tool execution with:
//! - Timeout protection (DoS prevention)
//! - Capability sandboxing
//! - Input validation
//! - Cryptographic result hashing
//! - Merkle audit trail integration
//!
//! # VEX Innovation
//!
//! Every successful tool execution produces a hashed result for the audit chain, with:
//! - Tool name and argument hash (not raw args for privacy)
//! - Result hash for verification
//! - Execution time metrics
//!
//! This enables cryptographic proof of what tools were used. The executor itself only
//! emits `tracing` events; persisting entries to the audit store is done by the runtime.
//!
//! # Security Considerations
//!
//! - All executions have configurable timeouts
//! - Validation runs before execution
//! - Audit logging is non-fatal (doesn't break execution)
//! - Arguments are hashed before logging (privacy protection)

use std::time::Instant;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};

use crate::tool::{Capability, ToolRegistry};
use crate::tool_error::ToolError;
use crate::tool_result::ToolResult;

33/// Tool executor with automatic audit logging and timeout protection.
34///
35/// The executor provides a safe, audited interface to execute tools:
36/// 1. Validates arguments against tool's schema
37/// 2. Executes with timeout protection
38/// 3. Hashes results for Merkle chain
39/// 4. Logs execution to audit trail (if configured)
40///
41/// # Example
42///
43/// ```ignore
44/// use vex_llm::{ToolExecutor, ToolRegistry};
45///
46/// let registry = ToolRegistry::with_builtins();
47/// let executor = ToolExecutor::new(registry);
48///
49/// let result = executor
50///     .execute("calculator", json!({"expression": "2+2"}))
51///     .await?;
52///
53/// println!("Result: {}", result.output);
54/// println!("Hash: {}", result.hash);
55/// ```
56pub struct ToolExecutor {
57    registry: ToolRegistry,
58    /// Enable/disable audit logging
59    audit_enabled: bool,
60    /// Maximum parallel executions (0 = unlimited)
61    max_parallel: usize,
62    /// Allowed capabilities for this executor (Security Sandbox)
63    allowed_capabilities: Vec<Capability>,
64}
65
66impl ToolExecutor {
67    /// Create a new executor with the given registry
68    pub fn new(registry: ToolRegistry) -> Self {
69        Self {
70            registry,
71            audit_enabled: true,
72            max_parallel: 0, // Unlimited by default
73            allowed_capabilities: vec![
74                Capability::PureComputation,
75                Capability::Network,
76                Capability::FileSystem,
77                Capability::Subprocess,
78                Capability::Environment,
79                Capability::Cryptography,
80            ],
81        }
82    }
83
84    /// Create executor with audit logging disabled
85    pub fn without_audit(registry: ToolRegistry) -> Self {
86        Self {
87            registry,
88            audit_enabled: false,
89            max_parallel: 0,
90            allowed_capabilities: vec![
91                Capability::PureComputation,
92                Capability::Network,
93                Capability::FileSystem,
94                Capability::Subprocess,
95                Capability::Environment,
96                Capability::Cryptography,
97            ],
98        }
99    }
100
101    /// Set maximum parallel executions
102    pub fn with_max_parallel(mut self, max: usize) -> Self {
103        self.max_parallel = max;
104        self
105    }
106
107    /// Set allowed capabilities for the sandbox
108    pub fn with_allowed_capabilities(mut self, caps: Vec<Capability>) -> Self {
109        self.allowed_capabilities = caps;
110        self
111    }
112
113    /// Register a WASM tool from bytes
114    pub fn register_wasm_tool(
115        &mut self,
116        definition: crate::tool::ToolDefinition,
117        module_bytes: Vec<u8>,
118        capabilities: Vec<Capability>,
119    ) {
120        use crate::wasm_tool::WasmTool;
121        use std::sync::Arc;
122        let tool = WasmTool::new(definition, module_bytes, capabilities);
123        self.registry.register(Arc::new(tool));
124    }
125
126    /// Register a WASM tool from a file path
127    pub fn register_wasm_tool_from_file(
128        &mut self,
129        definition: crate::tool::ToolDefinition,
130        path: impl AsRef<std::path::Path>,
131        capabilities: Vec<Capability>,
132    ) -> Result<(), std::io::Error> {
133        let bytes = std::fs::read(path)?;
134        self.register_wasm_tool(definition, bytes, capabilities);
135        Ok(())
136    }
137
138    /// Execute a tool by name with given arguments.
139    ///
140    /// # Arguments
141    ///
142    /// * `tool_name` - Name of the tool to execute
143    /// * `args` - JSON arguments to pass to the tool
144    ///
145    /// # Returns
146    ///
147    /// * `Ok(ToolResult)` - Execution result with hash for verification
148    /// * `Err(ToolError)` - If tool not found, validation failed, execution error, or timeout
149    ///
150    /// # Security
151    ///
152    /// - Tool lookup prevents arbitrary code execution
153    /// - Timeout prevents DoS from hanging tools
154    /// - Result hash enables tamper detection
155    pub async fn execute(
156        &self,
157        tool_name: &str,
158        args: serde_json::Value,
159    ) -> Result<ToolResult, ToolError> {
160        // 1. Get tool from registry
161        let tool = self.registry.get(tool_name).ok_or_else(|| {
162            warn!(tool = tool_name, "Tool not found");
163            ToolError::not_found(tool_name)
164        })?;
165
166        // 1.5. Check capabilities (Sandbox)
167        for cap in tool.capabilities() {
168            if !self.allowed_capabilities.contains(&cap) {
169                warn!(
170                    tool = tool_name,
171                    capability = ?cap,
172                    "Tool requires missing capability"
173                );
174                return Err(ToolError::unavailable(
175                    tool_name,
176                    format!("Sandbox violation: tool requires {:?}", cap),
177                ));
178            }
179        }
180
181        // 2. Check availability
182        if !tool.is_available() {
183            warn!(tool = tool_name, "Tool is unavailable");
184            return Err(ToolError::unavailable(
185                tool_name,
186                "Tool is currently disabled",
187            ));
188        }
189
190        // 3. Validate arguments against JSON Schema
191        debug!(tool = tool_name, "Validating arguments against schema");
192        let schema_str = tool.definition().parameters;
193        if !schema_str.is_empty() && schema_str != "{}" {
194            let schema_json: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
195                ToolError::execution_failed(tool_name, format!("Invalid tool schema: {}", e))
196            })?;
197
198            let compiled = jsonschema::JSONSchema::compile(&schema_json).map_err(|e| {
199                ToolError::execution_failed(
200                    tool_name,
201                    format!("Failed to compile tool schema: {}", e),
202                )
203            })?;
204
205            if !compiled.is_valid(&args) {
206                warn!(tool = tool_name, "Schema validation failed");
207                return Err(ToolError::invalid_args(
208                    tool_name,
209                    "Arguments do not match tool schema",
210                ));
211            }
212        }
213
214        // 4. Custom validation
215        debug!(tool = tool_name, "Running custom validation");
216        tool.validate(&args)?;
217
218        // 5. Execute with timeout
219        let tool_timeout = tool.timeout();
220        let start = Instant::now();
221
222        debug!(
223            tool = tool_name,
224            timeout_ms = tool_timeout.as_millis(),
225            "Executing tool"
226        );
227
228        let output = timeout(tool_timeout, tool.execute(args.clone()))
229            .await
230            .map_err(|_| {
231                error!(
232                    tool = tool_name,
233                    timeout_ms = tool_timeout.as_millis(),
234                    "Tool execution timed out"
235                );
236                ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
237            })??;
238
239        let elapsed = start.elapsed();
240
241        // 5. Create result with cryptographic hash
242        let result = ToolResult::new(tool_name, &args, output, elapsed);
243
244        // 6. Log execution metrics
245        info!(
246            tool = tool_name,
247            execution_ms = elapsed.as_millis(),
248            hash = %result.hash,
249            "Tool executed successfully"
250        );
251
252        // 7. Audit logging would happen here (integration point)
253        // Note: We log to tracing; actual AuditStore integration is in the runtime
254        if self.audit_enabled {
255            debug!(
256                tool = tool_name,
257                result_hash = %result.hash,
258                "Audit entry created"
259            );
260        }
261
262        Ok(result)
263    }
264
265    /// Execute multiple tools in parallel.
266    ///
267    /// # Arguments
268    ///
269    /// * `calls` - Vector of (tool_name, args) pairs
270    ///
271    /// # Returns
272    ///
273    /// Vector of results in the same order as input.
274    /// Each result is independent (one failure doesn't affect others).
275    ///
276    /// # Security
277    ///
278    /// - Respects max_parallel limit to prevent resource exhaustion
279    /// - Each tool has its own timeout
280    pub async fn execute_parallel(
281        &self,
282        calls: Vec<(String, serde_json::Value)>,
283    ) -> Vec<Result<ToolResult, ToolError>> {
284        debug!(count = calls.len(), "Executing tools in parallel");
285
286        if self.max_parallel > 0 {
287            use futures::stream::{self, StreamExt};
288
289            stream::iter(calls)
290                .map(|(name, args)| async move { self.execute(&name, args).await })
291                .buffered(self.max_parallel)
292                .collect()
293                .await
294        } else {
295            let futures: Vec<_> = calls
296                .into_iter()
297                .map(|(name, args)| {
298                    // Create an owned future that doesn't borrow the iterator
299                    async move { self.execute(&name, args).await }
300                })
301                .collect();
302
303            futures::future::join_all(futures).await
304        }
305    }
306
307    /// Get a reference to the tool registry
308    pub fn registry(&self) -> &ToolRegistry {
309        &self.registry
310    }
311
312    /// Get a mutable reference to the tool registry
313    pub fn registry_mut(&mut self) -> &mut ToolRegistry {
314        &mut self.registry
315    }
316
317    /// Check if a tool exists
318    pub fn has_tool(&self, name: &str) -> bool {
319        self.registry.contains(name)
320    }
321
322    /// List all available tool names
323    pub fn tool_names(&self) -> Vec<&str> {
324        self.registry.names()
325    }
326}
327
328impl std::fmt::Debug for ToolExecutor {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        f.debug_struct("ToolExecutor")
331            .field("tools", &self.registry.names())
332            .field("audit_enabled", &self.audit_enabled)
333            .field("max_parallel", &self.max_parallel)
334            .finish()
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    // Test tool that returns its arguments wrapped as {"echo": args}
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    // Test tool that always fails
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    // Test tool that outlives its (very short) timeout
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(50) // Very short timeout for testing
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    #[tokio::test]
    async fn test_execute_success() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert_eq!(result.output["echo"]["message"], "hello");
        assert!(!result.hash.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_execute_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);

        let result = executor.execute("nonexistent", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("fail", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));

        let executor = ToolExecutor::new(registry);
        let result = executor.execute("slow", serde_json::json!({})).await;

        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_execute_rejects_args_violating_schema() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);
        // EchoTool's schema is {"type": "object"}; a bare number must be rejected.
        let result = executor.execute("echo", serde_json::json!(42)).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_without_audit_still_executes() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::without_audit(registry);
        let result = executor
            .execute("echo", serde_json::json!({"k": "v"}))
            .await
            .unwrap();

        assert_eq!(result.output["echo"]["k"], "v");
        assert!(format!("{:?}", executor).contains("audit_enabled: false"));
    }

    #[tokio::test]
    async fn test_execute_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_execute_parallel_bounded_preserves_order() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        // Exercises the `buffered` branch (max_parallel > 0).
        let executor = ToolExecutor::new(registry).with_max_parallel(2);

        let calls: Vec<_> = (1..=5)
            .map(|n| ("echo".to_string(), serde_json::json!({ "n": n })))
            .collect();

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            let result = result.as_ref().expect("echo should succeed");
            assert_eq!(result.output["echo"]["n"], serde_json::json!(i + 1));
        }
    }

    #[tokio::test]
    async fn test_execute_parallel_isolates_failures() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        registry.register(Arc::new(FailTool::new()));

        let executor = ToolExecutor::new(registry);

        let calls = vec![
            ("echo".to_string(), serde_json::json!({})),
            ("fail".to_string(), serde_json::json!({})),
            ("missing".to_string(), serde_json::json!({})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results[0].is_ok());
        assert!(matches!(results[1], Err(ToolError::ExecutionFailed { .. })));
        assert!(matches!(results[2], Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_has_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));

        let executor = ToolExecutor::new(registry);

        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
    }
}