1use std::collections::HashMap;
53
54#[cfg(feature = "native")]
55use std::sync::{Arc, Mutex};
56#[cfg(feature = "native")]
57use std::time::Instant;
58
59#[cfg(feature = "wasm")]
60use std::cell::RefCell;
61#[cfg(feature = "wasm")]
62use std::rc::Rc;
63#[cfg(feature = "wasm")]
64use web_time::Instant;
65
66use rhai::{Engine, EvalAltResult, Scope};
67
68use crate::sandbox::ExecutionLimits;
69use crate::types::{OrchestratorError, OrchestratorResult, ToolCall};
70
/// Maximum nesting depth of expressions at the top level of a script.
///
/// Supplied as the first argument to Rhai's `Engine::set_max_expr_depths`.
const MAX_EXPR_DEPTH: usize = 64;

/// Maximum nesting depth of expressions inside script-defined functions.
///
/// Supplied as the second argument to Rhai's `Engine::set_max_expr_depths`.
/// NOTE(review): the name suggests a call-stack limit, but that argument bounds
/// expression nesting within function bodies — confirm this is the intent.
const MAX_CALL_DEPTH: usize = 64;
80
/// Growable list shared between `execute` and the tool closures it registers
/// with the Rhai engine (`native` feature: thread-safe handle).
#[cfg(feature = "native")]
pub type SharedVec<T> = Arc<Mutex<Vec<T>>>;

/// Growable list shared between `execute` and the tool closures it registers
/// with the Rhai engine (`wasm` feature: single-threaded handle).
#[cfg(feature = "wasm")]
pub type SharedVec<T> = Rc<RefCell<Vec<T>>>;

/// Shared tally of tool invocations (`native` feature).
#[cfg(feature = "native")]
pub type SharedCounter = Arc<Mutex<usize>>;

/// Shared tally of tool invocations (`wasm` feature).
#[cfg(feature = "wasm")]
pub type SharedCounter = Rc<RefCell<usize>>;

/// Host callback backing a tool: receives the script argument as JSON and
/// returns the tool's textual output or an error message (`native` feature).
#[cfg(feature = "native")]
pub type ToolExecutor = Arc<dyn Fn(serde_json::Value) -> Result<String, String> + Send + Sync>;

/// Host callback backing a tool: receives the script argument as JSON and
/// returns the tool's textual output or an error message (`wasm` feature).
#[cfg(feature = "wasm")]
pub type ToolExecutor = Rc<dyn Fn(serde_json::Value) -> Result<String, String>>;
121
/// Creates an empty [`SharedVec`].
///
/// Both the `Arc<Mutex<Vec<T>>>` and `Rc<RefCell<Vec<T>>>` flavours implement
/// `Default` as "new handle around an empty vector", so a single body serves
/// either feature.
#[cfg(any(feature = "native", feature = "wasm"))]
fn new_shared_vec<T>() -> SharedVec<T> {
    Default::default()
}
139
/// Creates a [`SharedCounter`] starting at zero.
///
/// `usize::default()` is `0`, and both the `Arc<Mutex<_>>` and
/// `Rc<RefCell<_>>` flavours wrap the default value, so one body serves
/// either feature.
#[cfg(any(feature = "native", feature = "wasm"))]
fn new_shared_counter() -> SharedCounter {
    Default::default()
}
149
/// Returns another handle to the same shared allocation (refcount bump only).
#[cfg(feature = "native")]
fn clone_shared<T>(handle: &Arc<T>) -> Arc<T>
where
    T: ?Sized,
{
    Arc::clone(handle)
}

/// Returns another handle to the same shared allocation (refcount bump only).
#[cfg(feature = "wasm")]
fn clone_shared<T>(handle: &Rc<T>) -> Rc<T>
where
    T: ?Sized,
{
    Rc::clone(handle)
}
159
/// Returns a copy of the current contents of `shared`.
///
/// Panics if the mutex is poisoned.
#[cfg(feature = "native")]
fn lock_vec<T: Clone>(shared: &SharedVec<T>) -> Vec<T> {
    let guard = shared.lock().unwrap();
    guard.to_vec()
}

/// Returns a copy of the current contents of `shared`.
///
/// Panics if the cell is currently mutably borrowed.
#[cfg(feature = "wasm")]
fn lock_vec<T: Clone>(shared: &SharedVec<T>) -> Vec<T> {
    let guard = shared.borrow();
    guard.to_vec()
}
169
/// Appends `item` to the shared list.
///
/// Panics if the mutex is poisoned.
#[cfg(feature = "native")]
fn push_to_vec<T>(shared: &SharedVec<T>, item: T) {
    let mut items = shared.lock().unwrap();
    items.push(item);
}

/// Appends `item` to the shared list.
///
/// Panics if the cell is already borrowed.
#[cfg(feature = "wasm")]
fn push_to_vec<T>(shared: &SharedVec<T>, item: T) {
    let mut items = shared.borrow_mut();
    items.push(item);
}
179
/// Bumps the shared counter by one, provided it is still below `max`.
///
/// Returns `Err(())` and leaves the counter untouched once `max` has been
/// reached. The check and the increment happen while holding one lock.
#[cfg(feature = "native")]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut calls_so_far = shared.lock().unwrap();
    if *calls_so_far < max {
        *calls_so_far += 1;
        Ok(())
    } else {
        Err(())
    }
}

/// Bumps the shared counter by one, provided it is still below `max`.
///
/// Returns `Err(())` and leaves the counter untouched once `max` has been
/// reached. The check and the increment happen under one `RefCell` borrow.
#[cfg(feature = "wasm")]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut calls_so_far = shared.borrow_mut();
    if *calls_so_far < max {
        *calls_so_far += 1;
        Ok(())
    } else {
        Err(())
    }
}
200
201pub struct ToolOrchestrator {
251 #[allow(dead_code)]
252 engine: Engine,
253 executors: HashMap<String, ToolExecutor>,
254}
255
256impl ToolOrchestrator {
257 #[must_use]
262 pub fn new() -> Self {
263 let mut engine = Engine::new();
264
265 engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
267
268 Self {
269 engine,
270 executors: HashMap::new(),
271 }
272 }
273
274 #[cfg(feature = "native")]
294 pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
295 where
296 F: Fn(serde_json::Value) -> Result<String, String> + Send + Sync + 'static,
297 {
298 self.executors.insert(name.into(), Arc::new(executor));
299 }
300
301 #[cfg(feature = "wasm")]
305 pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
306 where
307 F: Fn(serde_json::Value) -> Result<String, String> + 'static,
308 {
309 self.executors.insert(name.into(), Rc::new(executor));
310 }
311
312 pub fn execute(
343 &self,
344 script: &str,
345 limits: ExecutionLimits,
346 ) -> Result<OrchestratorResult, OrchestratorError> {
347 let start_time = Instant::now();
348 let tool_calls: SharedVec<ToolCall> = new_shared_vec();
349 let call_count: SharedCounter = new_shared_counter();
350
351 let mut engine = Engine::new();
353
354 engine.set_max_operations(limits.max_operations);
356 engine.set_max_string_size(limits.max_string_size);
357 engine.set_max_array_size(limits.max_array_size);
358 engine.set_max_map_size(limits.max_map_size);
359 engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
360
361 let timeout_ms = limits.timeout_ms;
363 let progress_start = Instant::now();
364 engine.on_progress(move |_ops| {
365 let elapsed = u64::try_from(progress_start.elapsed().as_millis()).unwrap_or(u64::MAX);
367 if elapsed > timeout_ms {
368 Some(rhai::Dynamic::from("timeout"))
369 } else {
370 None
371 }
372 });
373
374 for (name, executor) in &self.executors {
376 let exec = clone_shared(executor);
377 let calls = clone_shared(&tool_calls);
378 let count = clone_shared(&call_count);
379 let max_calls = limits.max_tool_calls;
380 let tool_name = name.clone();
381
382 engine.register_fn(name.as_str(), move |input: rhai::Dynamic| -> String {
384 let call_start = Instant::now();
385
386 if increment_counter(&count, max_calls).is_err() {
388 return format!("ERROR: Maximum tool calls ({max_calls}) exceeded");
389 }
390
391 let json_input = dynamic_to_json(&input);
393
394 let (output, success) = match exec(json_input.clone()) {
396 Ok(result) => (result, true),
397 Err(e) => (format!("Tool error: {e}"), false),
398 };
399
400 let duration_ms = u64::try_from(call_start.elapsed().as_millis()).unwrap_or(u64::MAX);
402 let call = ToolCall::new(
403 tool_name.clone(),
404 json_input,
405 output.clone(),
406 success,
407 duration_ms,
408 );
409 push_to_vec(&calls, call);
410
411 output
412 });
413 }
414
415 let ast = engine
417 .compile(script)
418 .map_err(|e| OrchestratorError::CompilationError(e.to_string()))?;
419
420 let mut scope = Scope::new();
422 let result = engine
423 .eval_ast_with_scope::<rhai::Dynamic>(&mut scope, &ast)
424 .map_err(|e| match *e {
425 EvalAltResult::ErrorTooManyOperations(_) => {
426 OrchestratorError::MaxOperationsExceeded(limits.max_operations)
427 }
428 EvalAltResult::ErrorTerminated(_, _) => {
429 OrchestratorError::Timeout(limits.timeout_ms)
430 }
431 _ => OrchestratorError::ExecutionError(e.to_string()),
432 })?;
433
434 let execution_time_ms = u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
435
436 let output = if result.is_string() {
438 result.into_string().unwrap_or_default()
439 } else if result.is_unit() {
440 String::new()
441 } else {
442 format!("{result:?}")
443 };
444
445 let calls = lock_vec(&tool_calls);
446 Ok(OrchestratorResult::success(output, calls, execution_time_ms))
447 }
448
449 #[must_use]
468 pub fn registered_tools(&self) -> Vec<&str> {
469 self.executors.keys().map(String::as_str).collect()
470 }
471}
472
473impl Default for ToolOrchestrator {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
479pub fn dynamic_to_json(value: &rhai::Dynamic) -> serde_json::Value {
510 if value.is_string() {
511 serde_json::Value::String(value.clone().into_string().unwrap_or_default())
512 } else if value.is_int() {
513 serde_json::Value::Number(serde_json::Number::from(value.clone().as_int().unwrap_or(0)))
514 } else if value.is_float() {
515 serde_json::json!(value.clone().as_float().unwrap_or(0.0))
516 } else if value.is_bool() {
517 serde_json::Value::Bool(value.clone().as_bool().unwrap_or(false))
518 } else if value.is_array() {
519 let arr: Vec<rhai::Dynamic> = value.clone().into_array().unwrap_or_default();
520 serde_json::Value::Array(arr.iter().map(dynamic_to_json).collect())
521 } else if value.is_map() {
522 let map: rhai::Map = value.clone().cast();
523 let mut json_map = serde_json::Map::new();
524 for (k, v) in &map {
525 json_map.insert(k.to_string(), dynamic_to_json(v));
526 }
527 serde_json::Value::Object(json_map)
528 } else if value.is_unit() {
529 serde_json::Value::Null
530 } else {
531 serde_json::Value::String(format!("{value:?}"))
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_orchestrator_creation() {
        let orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());
    }

    #[test]
    fn test_register_executor() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("test_tool", |_| Ok("success".to_string()));
        assert!(orchestrator.registered_tools().contains(&"test_tool"));
    }

    #[test]
    fn test_simple_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 1 + 2; x", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "3");
    }

    #[test]
    fn test_string_interpolation() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute(
                r#"let name = "world"; `Hello, ${name}!`"#,
                ExecutionLimits::default(),
            )
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Hello, world!");
    }

    #[test]
    fn test_tool_execution() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("greet", |input| {
            let name = input.as_str().unwrap_or("stranger");
            Ok(format!("Hello, {name}!"))
        });

        let result = orchestrator
            .execute(r#"greet("Claude")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Hello, Claude!");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].tool_name, "greet");
    }

    #[test]
    fn test_max_operations_limit() {
        let orchestrator = ToolOrchestrator::new();
        let limits = ExecutionLimits::default().with_max_operations(10);

        let result = orchestrator.execute(
            "let sum = 0; for i in 0..1000 { sum += i; } sum",
            limits,
        );

        assert!(matches!(
            result,
            Err(OrchestratorError::MaxOperationsExceeded(_))
        ));
    }

    #[test]
    fn test_compilation_error() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator.execute(
            "this is not valid rhai syntax {{{{",
            ExecutionLimits::default(),
        );

        assert!(matches!(result, Err(OrchestratorError::CompilationError(_))));
    }

    #[test]
    fn test_multiple_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();

        orchestrator.register_executor("add", |input| {
            if let Some(arr) = input.as_array() {
                let sum: i64 = arr.iter().filter_map(|v| v.as_i64()).sum();
                Ok(sum.to_string())
            } else {
                Err("Expected array".to_string())
            }
        });

        let script = r#"
            let a = add([1, 2, 3]);
            let b = add([4, 5, 6]);
            `Sum1: ${a}, Sum2: ${b}`
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 2);
        assert!(result.output.contains("Sum1: 6"));
        assert!(result.output.contains("Sum2: 15"));
    }

    #[test]
    fn test_tool_error_handling() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("fail_tool", |_| Err("Intentional failure".to_string()));

        let result = orchestrator
            .execute(r#"fail_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        // A failing tool does not fail the script; the error text becomes the
        // tool's return value and the call is recorded as unsuccessful.
        assert!(result.success);
        assert!(result.output.contains("Tool error"));
        assert_eq!(result.tool_calls.len(), 1);
        assert!(!result.tool_calls[0].success);
    }

    #[test]
    fn test_max_tool_calls_limit() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("count", |_| Ok("1".to_string()));

        let limits = ExecutionLimits::default().with_max_tool_calls(3);
        let script = r#"
            let a = count("1");
            let b = count("2");
            let c = count("3");
            count("4")
        "#;

        let result = orchestrator.execute(script, limits).unwrap();

        assert!(
            result.output.contains("Maximum tool calls"),
            "Expected error message about max tool calls, got: {}",
            result.output
        );
        assert_eq!(result.tool_calls.len(), 3);
    }

    #[test]
    fn test_tool_with_map_input() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("get_value", |input| {
            if let Some(obj) = input.as_object() {
                if let Some(key) = obj.get("key").and_then(|v| v.as_str()) {
                    Ok(format!("Got key: {key}"))
                } else {
                    Err("Missing key field".to_string())
                }
            } else {
                Err("Expected object".to_string())
            }
        });

        let result = orchestrator
            .execute(r#"get_value(#{ key: "test_key" })"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Got key: test_key");
    }

    #[test]
    fn test_loop_with_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("double", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok((n * 2).to_string())
        });

        let script = r#"
            let results = [];
            for i in 1..4 {
                results.push(double(i));
            }
            results
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 3);
    }

    #[test]
    fn test_conditional_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("check", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok(if n > 5 { "big" } else { "small" }.to_string())
        });

        let script = r#"
            let x = 10;
            if x > 5 {
                check(x)
            } else {
                "skipped"
            }
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "big");
        assert_eq!(result.tool_calls.len(), 1);
    }

    #[test]
    fn test_empty_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_unit_return() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 5;", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_dynamic_to_json_types() {
        use rhai::Dynamic;

        let d = Dynamic::from("hello".to_string());
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!("hello"));

        let d = Dynamic::from(42_i64);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(42));

        // 2.5 is exactly representable, so the float can be compared exactly.
        // The previous check (`x - 3.14 < 0.001`, no `abs`) accepted any value
        // below ~3.141, and `3.14` trips clippy's `approx_constant` lint.
        let d = Dynamic::from(2.5_f64);
        let j = dynamic_to_json(&d);
        assert_eq!(j.as_f64(), Some(2.5));

        let d = Dynamic::from(true);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(true));

        let d = Dynamic::UNIT;
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::Value::Null);
    }

    #[test]
    fn test_execution_time_recorded() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let sum = 0; for i in 0..100 { sum += i; } sum", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.execution_time_ms < 10000);
    }

    #[test]
    fn test_tool_call_duration_recorded() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("slow_tool", |_| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            Ok("done".to_string())
        });

        let result = orchestrator
            .execute(r#"slow_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 1);
        assert!(result.tool_calls[0].duration_ms >= 10);
    }

    #[test]
    fn test_default_impl() {
        let orchestrator = ToolOrchestrator::default();
        assert!(orchestrator.registered_tools().is_empty());

        let result = orchestrator
            .execute("1 + 1", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "2");
    }

    #[test]
    fn test_timeout_error() {
        let orchestrator = ToolOrchestrator::new();

        // A 1 ms budget with an operation limit large enough that the
        // wall-clock timeout is hit first.
        let limits = ExecutionLimits::default()
            .with_timeout_ms(1)
            .with_max_operations(1_000_000);
        let result = orchestrator.execute(
            r#"
                let sum = 0;
                for i in 0..1000000 {
                    sum += i;
                }
                sum
            "#,
            limits,
        );

        assert!(result.is_err());
        match result {
            Err(OrchestratorError::Timeout(ms)) => assert_eq!(ms, 1),
            _ => panic!("Expected Timeout error, got: {:?}", result),
        }
    }

    #[test]
    fn test_runtime_error() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator.execute("undefined_variable", ExecutionLimits::default());

        assert!(result.is_err());
        match result {
            Err(OrchestratorError::ExecutionError(msg)) => {
                assert!(msg.contains("undefined_variable") || msg.contains("not found"));
            }
            _ => panic!("Expected ExecutionError"),
        }
    }

    #[test]
    fn test_registered_tools() {
        let mut orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());

        orchestrator.register_executor("tool_a", |_| Ok("a".to_string()));
        orchestrator.register_executor("tool_b", |_| Ok("b".to_string()));

        let tools = orchestrator.registered_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"tool_a"));
        assert!(tools.contains(&"tool_b"));
    }

    #[test]
    fn test_dynamic_to_json_array() {
        use rhai::Dynamic;

        let arr: Vec<Dynamic> = vec![
            Dynamic::from(1_i64),
            Dynamic::from(2_i64),
            Dynamic::from(3_i64),
        ];
        let d = Dynamic::from(arr);
        let j = dynamic_to_json(&d);

        assert_eq!(j, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_dynamic_to_json_map() {
        use rhai::{Dynamic, Map};

        let mut map = Map::new();
        map.insert("key".into(), Dynamic::from("value".to_string()));
        map.insert("num".into(), Dynamic::from(42_i64));
        let d = Dynamic::from(map);
        let j = dynamic_to_json(&d);

        assert!(j.is_object());
        let obj = j.as_object().unwrap();
        assert_eq!(obj.get("key").unwrap(), &serde_json::json!("value"));
        assert_eq!(obj.get("num").unwrap(), &serde_json::json!(42));
    }

    #[test]
    fn test_non_string_result() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator
            .execute("42", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "42");
    }

    #[test]
    fn test_array_result() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator
            .execute("[1, 2, 3]", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains('1'));
        assert!(result.output.contains('2'));
        assert!(result.output.contains('3'));
    }

    #[test]
    fn test_dynamic_to_json_fallback() {
        use rhai::Dynamic;

        #[derive(Clone)]
        struct CustomType {
            // Never read: the struct exists only to be an unknown type to Rhai.
            #[allow(dead_code)]
            value: i32,
        }

        let custom = CustomType { value: 42 };
        let d = Dynamic::from(custom);
        let j = dynamic_to_json(&d);

        assert!(j.is_string());
        let s = j.as_str().unwrap();
        assert!(!s.is_empty());
    }
}