// tool_orchestrator/engine.rs

//! Rhai engine setup and tool orchestration.
//!
//! This module contains the core [`ToolOrchestrator`] struct that executes
//! Rhai scripts with access to registered tools. It implements Anthropic's
//! "Programmatic Tool Calling" pattern.
//!
//! # Architecture
//!
//! The orchestrator uses feature-gated thread-safety primitives:
//!
//! - **`native`** feature: uses `Arc<Mutex<T>>` for thread-safe execution
//! - **`wasm`** feature: uses `Rc<RefCell<T>>` for single-threaded WASM
//!
//! # Key Components
//!
//! - [`ToolOrchestrator`] - main entry point for script execution
//! - [`ToolExecutor`] - type alias for tool callback functions
//! - [`dynamic_to_json`] - converts Rhai values to JSON for tool input
//!
//! # Example
//!
//! ```ignore
//! use tool_orchestrator::{ToolOrchestrator, ExecutionLimits};
//!
//! let mut orchestrator = ToolOrchestrator::new();
//!
//! // Register a tool
//! orchestrator.register_executor("greet", |input| {
//!     let name = input.as_str().unwrap_or("world");
//!     Ok(format!("Hello, {}!", name))
//! });
//!
//! // Execute a script that uses the tool
//! let result = orchestrator.execute(
//!     r#"greet("Claude")"#,
//!     ExecutionLimits::default()
//! )?;
//!
//! assert_eq!(result.output, "Hello, Claude!");
//! ```
//!
//! # Security
//!
//! The Rhai engine is sandboxed by default, with no access to:
//! - the file system
//! - the network
//! - shell commands
//! - system time (except via the primitives the engine provides)
//!
//! All resource limits are enforced via [`ExecutionLimits`].

52use std::collections::HashMap;
53
54#[cfg(feature = "native")]
55use std::sync::{Arc, Mutex};
56#[cfg(feature = "native")]
57use std::time::Instant;
58
59#[cfg(feature = "wasm")]
60use std::cell::RefCell;
61#[cfg(feature = "wasm")]
62use std::rc::Rc;
63#[cfg(feature = "wasm")]
64use web_time::Instant;
65
66use rhai::{Engine, EvalAltResult, Scope};
67
68use crate::sandbox::ExecutionLimits;
69use crate::types::{OrchestratorError, OrchestratorResult, ToolCall};
70
// ============================================================================
// Engine Configuration Constants
// ============================================================================

/// Maximum expression nesting depth for top-level script code.
///
/// Passed as the first argument (`max_expr_depth`) of Rhai's
/// `Engine::set_max_expr_depths`; limits how deeply expressions may nest,
/// guarding against stack overflow from pathologically nested input.
const MAX_EXPR_DEPTH: usize = 64;

/// Maximum expression nesting depth inside script-defined function bodies.
///
/// Passed as the second argument (`max_function_expr_depth`) of Rhai's
/// `Engine::set_max_expr_depths`. Despite its name, this does NOT bound the
/// function *call* stack (recursion depth). Rhai controls that separately via
/// `Engine::set_max_call_levels`, which this module does not call, so Rhai's
/// default applies.
const MAX_CALL_DEPTH: usize = 64;
80
// ============================================================================
// Type aliases for thread-safety primitives (feature-gated)
// ============================================================================

/// Shared, mutable vector. `native` builds use `Arc<Mutex<Vec<T>>>` so it can
/// cross threads.
#[cfg(feature = "native")]
pub type SharedVec<T> = Arc<Mutex<Vec<T>>>;

/// Shared, mutable vector. `wasm` builds use the single-threaded
/// `Rc<RefCell<Vec<T>>>`.
#[cfg(feature = "wasm")]
pub type SharedVec<T> = Rc<RefCell<Vec<T>>>;

/// Shared, mutable counter. `native` builds use `Arc<Mutex<usize>>`.
#[cfg(feature = "native")]
pub type SharedCounter = Arc<Mutex<usize>>;

/// Shared, mutable counter. `wasm` builds use `Rc<RefCell<usize>>`.
#[cfg(feature = "wasm")]
pub type SharedCounter = Rc<RefCell<usize>>;

/// Callback that implements a tool (`native`: `Arc<dyn Fn + Send + Sync>`).
///
/// A tool is handed its input as JSON and answers with `Ok(output)` on success
/// or `Err(message)` on failure.
///
/// # Example
///
/// ```ignore
/// orchestrator.register_executor("my_tool", |input: serde_json::Value| {
///     // Process input and return result
///     Ok("result".to_string())
/// });
/// ```
#[cfg(feature = "native")]
pub type ToolExecutor = Arc<dyn Fn(serde_json::Value) -> Result<String, String> + Send + Sync>;

/// Callback that implements a tool (`wasm`: single-threaded `Rc<dyn Fn>`).
///
/// A tool is handed its input as JSON and answers with `Ok(output)` on success
/// or `Err(message)` on failure.
#[cfg(feature = "wasm")]
pub type ToolExecutor = Rc<dyn Fn(serde_json::Value) -> Result<String, String>>;
121
// ============================================================================
// Helper functions for shared state (feature-gated)
// ============================================================================
//
// The helpers below hide whether shared state is `Arc<Mutex<_>>` (`native`)
// or `Rc<RefCell<_>>` (`wasm`), so the orchestration code can call them
// without any `#[cfg]` branches of its own.

/// Create an empty [`SharedVec`].
///
/// `Default` builds the same value on both backends: `Arc`/`Rc` around a
/// `Mutex`/`RefCell` around `Vec::new()`.
#[cfg(any(feature = "native", feature = "wasm"))]
fn new_shared_vec<T>() -> SharedVec<T> {
    Default::default()
}

/// Create a [`SharedCounter`] starting at zero (`usize::default()`).
#[cfg(any(feature = "native", feature = "wasm"))]
fn new_shared_counter() -> SharedCounter {
    Default::default()
}
149
/// Hand out another reference-counted handle to the same `Arc` allocation
/// (no deep copy).
#[cfg(feature = "native")]
fn clone_shared<T: ?Sized>(handle: &Arc<T>) -> Arc<T> {
    Arc::clone(handle)
}

/// Hand out another reference-counted handle to the same `Rc` allocation
/// (no deep copy).
#[cfg(feature = "wasm")]
fn clone_shared<T: ?Sized>(handle: &Rc<T>) -> Rc<T> {
    Rc::clone(handle)
}
159
/// Return a snapshot copy of the shared vector's contents.
///
/// Panics if the mutex has been poisoned.
#[cfg(feature = "native")]
fn lock_vec<T: Clone>(vec: &SharedVec<T>) -> Vec<T> {
    let guard = vec.lock().unwrap();
    guard.to_vec()
}

/// Return a snapshot copy of the shared vector's contents.
///
/// Panics if the vector is currently mutably borrowed.
#[cfg(feature = "wasm")]
fn lock_vec<T: Clone>(vec: &SharedVec<T>) -> Vec<T> {
    let view = vec.borrow();
    view.to_vec()
}

/// Append `item` to the shared vector.
///
/// Panics if the mutex has been poisoned.
#[cfg(feature = "native")]
fn push_to_vec<T>(vec: &SharedVec<T>, item: T) {
    let mut guard = vec.lock().unwrap();
    guard.push(item);
}

/// Append `item` to the shared vector.
///
/// Panics if the vector is already borrowed.
#[cfg(feature = "wasm")]
fn push_to_vec<T>(vec: &SharedVec<T>, item: T) {
    let mut slot = vec.borrow_mut();
    slot.push(item);
}
179
/// Bump the counter by one unless it has already reached `max`.
///
/// Returns `Err(())` (leaving the counter unchanged) when the budget is spent.
#[cfg(feature = "native")]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut used = shared.lock().unwrap();
    if *used < max {
        *used += 1;
        Ok(())
    } else {
        Err(())
    }
}

/// Bump the counter by one unless it has already reached `max`.
///
/// Returns `Err(())` (leaving the counter unchanged) when the budget is spent.
#[cfg(feature = "wasm")]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut used = shared.borrow_mut();
    if *used < max {
        *used += 1;
        Ok(())
    } else {
        Err(())
    }
}
200
201// ============================================================================
202// ToolOrchestrator
203// ============================================================================
204
205/// Tool orchestrator - executes Rhai scripts with registered tool access.
206///
207/// The `ToolOrchestrator` is the main entry point for programmatic tool calling.
208/// It manages tool registration and script execution within a sandboxed Rhai
209/// environment.
210///
211/// # Features
212///
213/// - **Tool Registration**: Register Rust functions as callable tools
214/// - **Script Execution**: Run Rhai scripts that can invoke registered tools
215/// - **Resource Limits**: Configurable limits prevent runaway execution
216/// - **Audit Trail**: All tool calls are logged with timing information
217///
218/// # Thread Safety
219///
220/// - With the `native` feature, the orchestrator is thread-safe
221/// - With the `wasm` feature, it's single-threaded for WASM compatibility
222///
223/// # Example
224///
225/// ```ignore
226/// use tool_orchestrator::{ToolOrchestrator, ExecutionLimits};
227///
228/// let mut orchestrator = ToolOrchestrator::new();
229///
230/// // Register tools
231/// orchestrator.register_executor("add", |input| {
232///     let arr = input.as_array().unwrap();
233///     let sum: i64 = arr.iter().filter_map(|v| v.as_i64()).sum();
234///     Ok(sum.to_string())
235/// });
236///
237/// // Execute script
238/// let result = orchestrator.execute(
239///     r#"
240///     let a = add([1, 2, 3]);
241///     let b = add([4, 5, 6]);
242///     `Sum: ${a} + ${b}`
243///     "#,
244///     ExecutionLimits::default()
245/// )?;
246///
247/// println!("{}", result.output);  // "Sum: 6 + 15"
248/// println!("Tool calls: {}", result.tool_calls.len());  // 2
249/// ```
250pub struct ToolOrchestrator {
251    #[allow(dead_code)]
252    engine: Engine,
253    executors: HashMap<String, ToolExecutor>,
254}
255
256impl ToolOrchestrator {
257    /// Create a new tool orchestrator with default settings.
258    ///
259    /// Initializes a fresh Rhai engine with expression depth limits
260    /// and an empty tool registry.
261    #[must_use]
262    pub fn new() -> Self {
263        let mut engine = Engine::new();
264
265        // Limit expression nesting depth to prevent stack overflow
266        engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
267
268        Self {
269            engine,
270            executors: HashMap::new(),
271        }
272    }
273
274    /// Register a tool executor function (native version - thread-safe).
275    ///
276    /// The executor function receives JSON input from the Rhai script and
277    /// returns either a success string or an error string.
278    ///
279    /// # Arguments
280    ///
281    /// * `name` - The name the tool will be callable as in Rhai scripts
282    /// * `executor` - Function that processes tool calls
283    ///
284    /// # Example
285    ///
286    /// ```ignore
287    /// orchestrator.register_executor("fetch_user", |input| {
288    ///     let user_id = input.as_i64().ok_or("Expected user ID")?;
289    ///     // Fetch user from database...
290    ///     Ok(format!(r#"{{"id": {}, "name": "Alice"}}"#, user_id))
291    /// });
292    /// ```
293    #[cfg(feature = "native")]
294    pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
295    where
296        F: Fn(serde_json::Value) -> Result<String, String> + Send + Sync + 'static,
297    {
298        self.executors.insert(name.into(), Arc::new(executor));
299    }
300
301    /// Register a tool executor function (WASM version - single-threaded).
302    ///
303    /// See the native version for full documentation.
304    #[cfg(feature = "wasm")]
305    pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
306    where
307        F: Fn(serde_json::Value) -> Result<String, String> + 'static,
308    {
309        self.executors.insert(name.into(), Rc::new(executor));
310    }
311
312    /// Execute a Rhai script with access to registered tools.
313    ///
314    /// Compiles and runs the provided Rhai script, making all registered
315    /// tools available as callable functions. Execution is bounded by the
316    /// provided [`ExecutionLimits`].
317    ///
318    /// # Arguments
319    ///
320    /// * `script` - Rhai source code to execute
321    /// * `limits` - Resource limits for this execution
322    ///
323    /// # Returns
324    ///
325    /// On success, returns [`OrchestratorResult`] containing:
326    /// - The script's output (final expression value)
327    /// - A log of all tool calls made
328    /// - Execution timing information
329    ///
330    /// # Errors
331    ///
332    /// Returns [`OrchestratorError`] if:
333    /// - Script fails to compile ([`CompilationError`])
334    /// - Script throws a runtime error ([`ExecutionError`])
335    /// - Operation limit exceeded ([`MaxOperationsExceeded`])
336    /// - Time limit exceeded ([`Timeout`])
337    ///
338    /// [`CompilationError`]: OrchestratorError::CompilationError
339    /// [`ExecutionError`]: OrchestratorError::ExecutionError
340    /// [`MaxOperationsExceeded`]: OrchestratorError::MaxOperationsExceeded
341    /// [`Timeout`]: OrchestratorError::Timeout
342    pub fn execute(
343        &self,
344        script: &str,
345        limits: ExecutionLimits,
346    ) -> Result<OrchestratorResult, OrchestratorError> {
347        let start_time = Instant::now();
348        let tool_calls: SharedVec<ToolCall> = new_shared_vec();
349        let call_count: SharedCounter = new_shared_counter();
350
351        // Create a new engine with limits for this execution
352        let mut engine = Engine::new();
353
354        // Apply resource limits from ExecutionLimits
355        engine.set_max_operations(limits.max_operations);
356        engine.set_max_string_size(limits.max_string_size);
357        engine.set_max_array_size(limits.max_array_size);
358        engine.set_max_map_size(limits.max_map_size);
359        engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
360
361        // Set up real-time timeout via on_progress callback
362        let timeout_ms = limits.timeout_ms;
363        let progress_start = Instant::now();
364        engine.on_progress(move |_ops| {
365            // Use saturating conversion - elapsed time exceeding u64::MAX is always a timeout
366            let elapsed = u64::try_from(progress_start.elapsed().as_millis()).unwrap_or(u64::MAX);
367            if elapsed > timeout_ms {
368                Some(rhai::Dynamic::from("timeout"))
369            } else {
370                None
371            }
372        });
373
374        // Register each tool as a Rhai function
375        for (name, executor) in &self.executors {
376            let exec = clone_shared(executor);
377            let calls = clone_shared(&tool_calls);
378            let count = clone_shared(&call_count);
379            let max_calls = limits.max_tool_calls;
380            let tool_name = name.clone();
381
382            // Register as a function that takes a Dynamic and returns a String
383            engine.register_fn(name.as_str(), move |input: rhai::Dynamic| -> String {
384                let call_start = Instant::now();
385
386                // Check call limit
387                if increment_counter(&count, max_calls).is_err() {
388                    return format!("ERROR: Maximum tool calls ({max_calls}) exceeded");
389                }
390
391                // Convert Dynamic to JSON
392                let json_input = dynamic_to_json(&input);
393
394                // Execute the tool
395                let (output, success) = match exec(json_input.clone()) {
396                    Ok(result) => (result, true),
397                    Err(e) => (format!("Tool error: {e}"), false),
398                };
399
400                // Record the call (saturate to u64::MAX for extremely long-running calls)
401                let duration_ms = u64::try_from(call_start.elapsed().as_millis()).unwrap_or(u64::MAX);
402                let call = ToolCall::new(
403                    tool_name.clone(),
404                    json_input,
405                    output.clone(),
406                    success,
407                    duration_ms,
408                );
409                push_to_vec(&calls, call);
410
411                output
412            });
413        }
414
415        // Compile the script
416        let ast = engine
417            .compile(script)
418            .map_err(|e| OrchestratorError::CompilationError(e.to_string()))?;
419
420        // Execute with timeout handling
421        let mut scope = Scope::new();
422        let result = engine
423            .eval_ast_with_scope::<rhai::Dynamic>(&mut scope, &ast)
424            .map_err(|e| match *e {
425                EvalAltResult::ErrorTooManyOperations(_) => {
426                    OrchestratorError::MaxOperationsExceeded(limits.max_operations)
427                }
428                EvalAltResult::ErrorTerminated(_, _) => {
429                    OrchestratorError::Timeout(limits.timeout_ms)
430                }
431                _ => OrchestratorError::ExecutionError(e.to_string()),
432            })?;
433
434        let execution_time_ms = u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
435
436        // Convert result to string
437        let output = if result.is_string() {
438            result.into_string().unwrap_or_default()
439        } else if result.is_unit() {
440            String::new()
441        } else {
442            format!("{result:?}")
443        };
444
445        let calls = lock_vec(&tool_calls);
446        Ok(OrchestratorResult::success(output, calls, execution_time_ms))
447    }
448
449    /// Get list of registered tool names.
450    ///
451    /// Returns the names of all tools that have been registered with
452    /// [`register_executor`]. These names are callable as functions
453    /// in Rhai scripts.
454    ///
455    /// [`register_executor`]: Self::register_executor
456    ///
457    /// # Example
458    ///
459    /// ```ignore
460    /// orchestrator.register_executor("tool_a", |_| Ok("a".into()));
461    /// orchestrator.register_executor("tool_b", |_| Ok("b".into()));
462    ///
463    /// let tools = orchestrator.registered_tools();
464    /// assert!(tools.contains(&"tool_a"));
465    /// assert!(tools.contains(&"tool_b"));
466    /// ```
467    #[must_use]
468    pub fn registered_tools(&self) -> Vec<&str> {
469        self.executors.keys().map(String::as_str).collect()
470    }
471}
472
473impl Default for ToolOrchestrator {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
479// ============================================================================
480// Helper functions
481// ============================================================================
482
483/// Convert Rhai [`Dynamic`] value to [`serde_json::Value`].
484///
485/// This function handles the conversion of Rhai's dynamic type system to
486/// JSON for passing data to tool executors. Supports all common Rhai types:
487///
488/// - Strings → JSON strings
489/// - Integers → JSON numbers
490/// - Floats → JSON numbers
491/// - Booleans → JSON booleans
492/// - Arrays → JSON arrays (recursive)
493/// - Maps → JSON objects (recursive)
494/// - Unit → JSON null
495/// - Other → Debug string representation
496///
497/// # Example
498///
499/// ```ignore
500/// use rhai::Dynamic;
501/// use tool_orchestrator::dynamic_to_json;
502///
503/// let d = Dynamic::from("hello");
504/// let j = dynamic_to_json(&d);
505/// assert_eq!(j, serde_json::json!("hello"));
506/// ```
507///
508/// [`Dynamic`]: rhai::Dynamic
509pub fn dynamic_to_json(value: &rhai::Dynamic) -> serde_json::Value {
510    if value.is_string() {
511        serde_json::Value::String(value.clone().into_string().unwrap_or_default())
512    } else if value.is_int() {
513        serde_json::Value::Number(serde_json::Number::from(value.clone().as_int().unwrap_or(0)))
514    } else if value.is_float() {
515        serde_json::json!(value.clone().as_float().unwrap_or(0.0))
516    } else if value.is_bool() {
517        serde_json::Value::Bool(value.clone().as_bool().unwrap_or(false))
518    } else if value.is_array() {
519        let arr: Vec<rhai::Dynamic> = value.clone().into_array().unwrap_or_default();
520        serde_json::Value::Array(arr.iter().map(dynamic_to_json).collect())
521    } else if value.is_map() {
522        let map: rhai::Map = value.clone().cast();
523        let mut json_map = serde_json::Map::new();
524        for (k, v) in &map {
525            json_map.insert(k.to_string(), dynamic_to_json(v));
526        }
527        serde_json::Value::Object(json_map)
528    } else if value.is_unit() {
529        serde_json::Value::Null
530    } else {
531        serde_json::Value::String(format!("{value:?}"))
532    }
533}
534
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_orchestrator_creation() {
        let orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());
    }

    #[test]
    fn test_register_executor() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("test_tool", |_| Ok("success".to_string()));
        assert!(orchestrator.registered_tools().contains(&"test_tool"));
    }

    #[test]
    fn test_simple_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 1 + 2; x", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "3");
    }

    #[test]
    fn test_string_interpolation() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute(
                r#"let name = "world"; `Hello, ${name}!`"#,
                ExecutionLimits::default(),
            )
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Hello, world!");
    }

    #[test]
    fn test_tool_execution() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("greet", |input| {
            let name = input.as_str().unwrap_or("stranger");
            Ok(format!("Hello, {}!", name))
        });

        let result = orchestrator
            .execute(r#"greet("Claude")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Hello, Claude!");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].tool_name, "greet");
    }

    #[test]
    fn test_max_operations_limit() {
        let orchestrator = ToolOrchestrator::new();
        let limits = ExecutionLimits::default().with_max_operations(10);

        // This should exceed the operations limit
        let result = orchestrator.execute(
            "let sum = 0; for i in 0..1000 { sum += i; } sum",
            limits,
        );

        assert!(matches!(
            result,
            Err(OrchestratorError::MaxOperationsExceeded(_))
        ));
    }

    #[test]
    fn test_compilation_error() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator.execute(
            "this is not valid rhai syntax {{{{",
            ExecutionLimits::default(),
        );

        assert!(matches!(result, Err(OrchestratorError::CompilationError(_))));
    }

    #[test]
    fn test_multiple_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();

        orchestrator.register_executor("add", |input| {
            if let Some(arr) = input.as_array() {
                let sum: i64 = arr.iter().filter_map(|v| v.as_i64()).sum();
                Ok(sum.to_string())
            } else {
                Err("Expected array".to_string())
            }
        });

        let script = r#"
            let a = add([1, 2, 3]);
            let b = add([4, 5, 6]);
            `Sum1: ${a}, Sum2: ${b}`
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 2);
        assert!(result.output.contains("Sum1: 6"));
        assert!(result.output.contains("Sum2: 15"));
    }

    #[test]
    fn test_tool_error_handling() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("fail_tool", |_| Err("Intentional failure".to_string()));

        let result = orchestrator
            .execute(r#"fail_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success); // Script completes, tool error is in output
        assert!(result.output.contains("Tool error"));
        assert_eq!(result.tool_calls.len(), 1);
        assert!(!result.tool_calls[0].success);
    }

    #[test]
    fn test_max_tool_calls_limit() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("count", |_| Ok("1".to_string()));

        let limits = ExecutionLimits::default().with_max_tool_calls(3);
        // Return the 4th call result directly so we can see the error
        let script = r#"
            let a = count("1");
            let b = count("2");
            let c = count("3");
            count("4")
        "#;

        let result = orchestrator.execute(script, limits).unwrap();

        // Fourth call should return error message instead of executing
        assert!(
            result.output.contains("Maximum tool calls"),
            "Expected error message about max tool calls, got: {}",
            result.output
        );
        // Only 3 calls should be recorded (the 4th was blocked)
        assert_eq!(result.tool_calls.len(), 3);
    }

    #[test]
    fn test_tool_with_map_input() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("get_value", |input| {
            if let Some(obj) = input.as_object() {
                if let Some(key) = obj.get("key").and_then(|v| v.as_str()) {
                    Ok(format!("Got key: {}", key))
                } else {
                    Err("Missing key field".to_string())
                }
            } else {
                Err("Expected object".to_string())
            }
        });

        let result = orchestrator
            .execute(r#"get_value(#{ key: "test_key" })"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Got key: test_key");
    }

    #[test]
    fn test_loop_with_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("double", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok((n * 2).to_string())
        });

        let script = r#"
            let results = [];
            for i in 1..4 {
                results.push(double(i));
            }
            results
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 3);
    }

    #[test]
    fn test_conditional_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("check", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok(if n > 5 { "big" } else { "small" }.to_string())
        });

        let script = r#"
            let x = 10;
            if x > 5 {
                check(x)
            } else {
                "skipped"
            }
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "big");
        assert_eq!(result.tool_calls.len(), 1);
    }

    #[test]
    fn test_empty_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_unit_return() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 5;", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty()); // Unit type returns empty string
    }

    #[test]
    fn test_dynamic_to_json_types() {
        // Test various Rhai Dynamic types convert to JSON correctly
        use rhai::Dynamic;

        // String
        let d = Dynamic::from("hello".to_string());
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!("hello"));

        // Integer
        let d = Dynamic::from(42_i64);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(42));

        // Float: 2.5 is exactly representable, so an exact comparison is safe.
        // (The previous check `x - 3.14 < 0.001` lacked `.abs()` and accepted
        // any smaller value; `3.14` also trips clippy's `approx_constant`.)
        let d = Dynamic::from(2.5_f64);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(2.5));

        // Boolean
        let d = Dynamic::from(true);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(true));

        // Unit (null)
        let d = Dynamic::UNIT;
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::Value::Null);
    }

    #[test]
    fn test_execution_time_recorded() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let sum = 0; for i in 0..100 { sum += i; } sum", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        // A tiny script should report a small, plausible execution time.
        assert!(result.execution_time_ms < 10000); // Should complete in under 10 seconds
    }

    #[test]
    fn test_tool_call_duration_recorded() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("slow_tool", |_| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            Ok("done".to_string())
        });

        let result = orchestrator
            .execute(r#"slow_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 1);
        assert!(result.tool_calls[0].duration_ms >= 10);
    }

    #[test]
    fn test_default_impl() {
        // Test that Default::default() works for ToolOrchestrator
        let orchestrator = ToolOrchestrator::default();
        assert!(orchestrator.registered_tools().is_empty());

        // Execute a simple script to verify it works
        let result = orchestrator
            .execute("1 + 1", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "2");
    }

    #[test]
    fn test_timeout_error() {
        let orchestrator = ToolOrchestrator::new();

        // Use a CPU-intensive loop that will trigger on_progress checks
        // Set timeout to 1ms - the loop will exceed this quickly
        let limits = ExecutionLimits::default()
            .with_timeout_ms(1)
            .with_max_operations(1_000_000); // Allow many ops so timeout triggers first

        // This loop will keep running until timeout kicks in via on_progress
        let result = orchestrator.execute(
            r#"
            let sum = 0;
            for i in 0..1000000 {
                sum += i;
            }
            sum
            "#,
            limits,
        );

        // Should return a timeout error (real-time via on_progress)
        assert!(result.is_err());
        match result {
            Err(OrchestratorError::Timeout(ms)) => assert_eq!(ms, 1),
            _ => panic!("Expected Timeout error, got: {:?}", result),
        }
    }

    #[test]
    fn test_runtime_error() {
        let orchestrator = ToolOrchestrator::new();

        // This should cause a runtime error (undefined variable)
        let result = orchestrator.execute("undefined_variable", ExecutionLimits::default());

        assert!(result.is_err());
        match result {
            Err(OrchestratorError::ExecutionError(msg)) => {
                assert!(msg.contains("undefined_variable") || msg.contains("not found"));
            }
            _ => panic!("Expected ExecutionError, got: {:?}", result),
        }
    }

    #[test]
    fn test_registered_tools() {
        let mut orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());

        orchestrator.register_executor("tool_a", |_| Ok("a".to_string()));
        orchestrator.register_executor("tool_b", |_| Ok("b".to_string()));

        let tools = orchestrator.registered_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"tool_a"));
        assert!(tools.contains(&"tool_b"));
    }

    #[test]
    fn test_dynamic_to_json_array() {
        use rhai::Dynamic;

        // Create an array
        let arr: Vec<Dynamic> = vec![
            Dynamic::from(1_i64),
            Dynamic::from(2_i64),
            Dynamic::from(3_i64),
        ];
        let d = Dynamic::from(arr);
        let j = dynamic_to_json(&d);

        assert_eq!(j, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_dynamic_to_json_map() {
        use rhai::{Dynamic, Map};

        // Create a map
        let mut map = Map::new();
        map.insert("key".into(), Dynamic::from("value".to_string()));
        map.insert("num".into(), Dynamic::from(42_i64));
        let d = Dynamic::from(map);
        let j = dynamic_to_json(&d);

        assert!(j.is_object());
        let obj = j.as_object().unwrap();
        assert_eq!(obj.get("key").unwrap(), &serde_json::json!("value"));
        assert_eq!(obj.get("num").unwrap(), &serde_json::json!(42));
    }

    #[test]
    fn test_non_string_result() {
        // Test that non-string results are formatted with Debug
        let orchestrator = ToolOrchestrator::new();

        // Return an integer (not a string)
        let result = orchestrator
            .execute("42", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "42");
    }

    #[test]
    fn test_array_result() {
        // Test that array results are formatted
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator
            .execute("[1, 2, 3]", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        // Arrays are formatted with Debug
        assert!(result.output.contains("1"));
        assert!(result.output.contains("2"));
        assert!(result.output.contains("3"));
    }

    #[test]
    fn test_dynamic_to_json_fallback() {
        use rhai::Dynamic;

        // A user-defined custom type matches none of the standard Rhai types,
        // so it must take the Debug-string fallback branch.
        #[derive(Clone)]
        struct CustomType {
            #[allow(dead_code)] // only exists to give the type some content
            value: i32,
        }

        let custom = CustomType { value: 42 };
        let d = Dynamic::from(custom);
        let j = dynamic_to_json(&d);

        // Should fall back to string representation via Debug
        assert!(j.is_string());
        // The string should contain some representation of the type
        let s = j.as_str().unwrap();
        assert!(!s.is_empty());
    }
}