1use std::{
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use tokio::{task::JoinSet, time::timeout};
11
12use super::cache::{CacheKey, ToolCallCache};
13use crate::{
14 model::{
15 chat_base_response::ToolCallMessage,
16 chat_message_types::TextMessage,
17 tools::{Function, Tools},
18 },
19 toolkits::{
20 core::DynTool,
21 error::{ToolResult, error_context},
22 },
23};
24
25type ToolHandler = std::sync::Arc<
27 dyn Fn(
28 serde_json::Value,
29 ) -> std::pin::Pin<
30 Box<
31 dyn std::future::Future<
32 Output = crate::toolkits::error::ToolResult<serde_json::Value>,
33 > + Send,
34 >,
35 > + Send
36 + Sync,
37>;
38
/// Exponential-backoff policy applied between failed tool attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the first failed attempt
    /// (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on any single retry delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each subsequent retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms, doubling each time, capped at 30 s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns the delay to wait before retry number `attempt` (1-based).
    ///
    /// Attempt `0` (the initial try) has no delay. Retry `n >= 1` waits
    /// `initial_delay * backoff_multiplier^(n - 1)`, computed in whole
    /// milliseconds and capped at `max_delay`. Negative results clamp to zero.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        // Saturate rather than `as i32`: a wrapping cast turns attempts above
        // i32::MAX into a negative exponent and collapses the backoff to ~0.
        let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
        let max_ms = self.max_delay.as_millis() as f64;
        let raw_ms = self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(exponent);

        // `f64::min` returns `max_ms` for NaN/infinite input; the float-to-int
        // cast saturates negatives to 0.
        Duration::from_millis(raw_ms.min(max_ms) as u64)
    }
}
72
73#[derive(Debug, Clone)]
75pub struct ExecutionConfig {
76 pub timeout: Option<Duration>,
77 pub retry_config: RetryConfig,
78 pub validate_parameters: bool,
79 pub enable_logging: bool,
80}
81
82impl Default for ExecutionConfig {
83 fn default() -> Self {
84 Self {
85 timeout: Some(Duration::from_secs(30)),
86 retry_config: RetryConfig::default(),
87 validate_parameters: true,
88 enable_logging: false,
89 }
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ExecutionResult {
96 pub tool_name: String,
97 pub result: serde_json::Value,
98 pub duration: Duration,
99 pub success: bool,
100 pub error: Option<String>,
101 pub retries: u32,
102 pub timestamp: std::time::SystemTime,
103 pub metadata: std::collections::HashMap<String, serde_json::Value>,
104}
105
106impl ExecutionResult {
107 pub fn success(
108 tool_name: String,
109 result: serde_json::Value,
110 duration: Duration,
111 retries: u32,
112 ) -> Self {
113 Self {
114 tool_name,
115 result,
116 duration,
117 success: true,
118 error: None,
119 retries,
120 timestamp: std::time::SystemTime::now(),
121 metadata: std::collections::HashMap::new(),
122 }
123 }
124
125 pub fn failure(tool_name: String, error: String, duration: Duration, retries: u32) -> Self {
126 Self {
127 tool_name,
128 result: serde_json::Value::Null,
129 duration,
130 success: false,
131 error: Some(error),
132 retries,
133 timestamp: std::time::SystemTime::now(),
134 metadata: std::collections::HashMap::new(),
135 }
136 }
137
138 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
139 self.metadata.insert(key.into(), value);
140 self
141 }
142}
143
144#[derive(Clone)]
146pub struct ToolExecutor {
147 tools: Arc<DashMap<String, Box<dyn DynTool>>>,
148 config: ExecutionConfig,
149 cache: ToolCallCache,
150}
151
152impl std::fmt::Debug for ToolExecutor {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 let tool_count = self.tools.len();
155 let cache_enabled = self.cache.stats().total_entries > 0;
156 f.debug_struct("ToolExecutor")
157 .field("tool_count", &tool_count)
158 .field("config", &self.config)
159 .field("cache_enabled", &cache_enabled)
160 .finish()
161 }
162}
163
164impl Default for ToolExecutor {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl ToolExecutor {
171 pub fn new() -> Self {
173 Self {
174 tools: Arc::new(DashMap::new()),
175 config: ExecutionConfig::default(),
176 cache: ToolCallCache::new(),
177 }
178 }
179
180 pub fn builder() -> ExecutorBuilder {
182 ExecutorBuilder::new()
183 }
184
185 pub fn with_cache_enabled(mut self, enabled: bool) -> Self {
187 self.cache = self.cache.with_enabled(enabled);
188 self
189 }
190
191 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
193 self.cache = self.cache.with_ttl(ttl);
194 self
195 }
196
197 pub fn with_cache_max_size(mut self, size: usize) -> Self {
199 self.cache = self.cache.with_max_size(size);
200 self
201 }
202
203 pub fn clear_cache(&self) {
205 self.cache.clear();
206 }
207
208 pub fn invalidate_cache_for_tool(&self, tool_name: &str) {
210 self.cache.invalidate_tool(tool_name);
211 }
212
213 pub fn cache_stats(&self) -> super::cache::CacheStats {
215 self.cache.stats()
216 }
217
218 pub fn add_dyn_tool(&self, tool: Box<dyn DynTool>) -> &Self {
220 let name = tool.name().to_string();
221 if self.tools.contains_key(&name) {
222 panic!("Tool '{}' is already registered", name);
223 }
224 self.tools.insert(name, tool);
225 self
226 }
227
228 pub fn try_add_dyn_tool(&self, tool: Box<dyn DynTool>) -> &Self {
230 let name = tool.name().to_string();
231 self.tools.entry(name).or_insert(tool);
232 self
233 }
234
235 pub fn unregister(&self, name: &str) -> ToolResult<()> {
237 if self.tools.remove(name).is_none() {
238 return Err(error_context().tool_not_found());
239 }
240 Ok(())
241 }
242
243 pub fn input_schema(&self, name: &str) -> Option<serde_json::Value> {
245 self.tools.get(name).map(|t| t.input_schema())
246 }
247
248 pub fn has_tool(&self, name: &str) -> bool {
250 self.tools.contains_key(name)
251 }
252
253 pub fn tool_names(&self) -> Vec<String> {
255 self.tools.iter().map(|entry| entry.key().clone()).collect()
256 }
257
258 fn get_tool(&self, name: &str) -> Option<Box<dyn DynTool>> {
259 self.tools.get(name).map(|t| t.clone_box())
260 }
261
262 pub async fn execute(
264 &self,
265 tool_name: &str,
266 input: serde_json::Value,
267 ) -> ToolResult<ExecutionResult> {
268 let start_time = Instant::now();
269 let mut retries = 0;
270 let retry_config = &self.config.retry_config;
271
272 let cache_key = CacheKey::new(tool_name.to_string(), input.clone());
274 if let Some(cached_result) = self.cache.get(&cache_key) {
275 let duration = start_time.elapsed();
276 return Ok(ExecutionResult::success(
277 tool_name.to_string(),
278 cached_result,
279 duration,
280 retries,
281 )
282 .with_metadata("cache_hit", serde_json::Value::Bool(true)));
283 }
284
285 loop {
286 match self.execute_once(tool_name, &input).await {
287 Ok(result) => {
288 let duration = start_time.elapsed();
289 self.cache.insert(cache_key, result.clone(), None);
291
292 return Ok(ExecutionResult::success(
293 tool_name.to_string(),
294 result,
295 duration,
296 retries,
297 )
298 .with_metadata("cache_hit", serde_json::Value::Bool(false)));
299 },
300 Err(error) => {
301 if retries >= retry_config.max_retries {
302 let duration = start_time.elapsed();
303 return Ok(ExecutionResult::failure(
304 tool_name.to_string(),
305 error.to_string(),
306 duration,
307 retries,
308 ));
309 }
310
311 retries += 1;
312
313 if self.config.enable_logging {
314 eprintln!("Tool execution failed (attempt {}): {}", retries, error);
315 }
316
317 let delay = retry_config.calculate_delay(retries);
319 tokio::time::sleep(delay).await;
320 },
321 }
322 }
323 }
324
325 pub async fn execute_simple(
327 &self,
328 tool_name: &str,
329 input: serde_json::Value,
330 ) -> ToolResult<serde_json::Value> {
331 let result = self.execute(tool_name, input).await?;
332 if result.success {
333 Ok(result.result)
334 } else {
335 Err(error_context()
336 .with_tool(tool_name)
337 .execution_failed(result.error.unwrap_or_else(|| "Unknown error".to_string())))
338 }
339 }
340
341 pub fn add_functions_from_dir_with_registry(
354 &self,
355 dir: impl AsRef<std::path::Path>,
356 handlers: &std::collections::HashMap<String, ToolHandler>,
357 strict: bool,
358 ) -> ToolResult<Vec<String>> {
359 use std::fs;
360
361 use serde_json::Value;
362 let dir = dir.as_ref();
363 let mut added = Vec::new();
364 let read_dir = fs::read_dir(dir).map_err(|e| {
365 error_context().invalid_parameters(format!(
366 "Failed to read dir {}: {}",
367 dir.display(),
368 e
369 ))
370 })?;
371 for entry in read_dir {
372 let entry = match entry {
373 Ok(e) => e,
374 Err(e) => {
375 return Err(
376 error_context().invalid_parameters(format!("Dir entry error: {}", e))
377 );
378 },
379 };
380 let path = entry.path();
381 if !path.is_file() {
382 continue;
383 }
384 if path.extension().and_then(|s| s.to_str()) != Some("json") {
385 continue;
386 }
387 let content = fs::read_to_string(&path).map_err(|e| {
388 error_context().invalid_parameters(format!(
389 "Failed to read {}: {}",
390 path.display(),
391 e
392 ))
393 })?;
394 let spec: Value = serde_json::from_str(&content).map_err(|e| {
395 error_context().invalid_parameters(format!(
396 "Invalid JSON in {}: {}",
397 path.display(),
398 e
399 ))
400 })?;
401
402 let (name, description, parameters) =
404 crate::toolkits::core::parse_function_spec_details(&spec).map_err(|e| {
405 error_context().invalid_parameters(format!(
406 "Failed to parse spec {}: {}",
407 path.display(),
408 e
409 ))
410 })?;
411
412 let handler = match handlers.get(&name) {
413 Some(h) => h.clone(),
414 None => {
415 if strict {
416 return Err(error_context().invalid_parameters(format!(
417 "No handler registered for function '{}' (file {})",
418 name,
419 path.display()
420 )));
421 } else {
422 continue;
424 }
425 },
426 };
427
428 let mut builder =
431 crate::toolkits::core::FunctionTool::builder(name.clone(), description);
432 if let Some(p) = parameters {
433 builder = builder.schema(p);
434 }
435 let tool = builder
436 .handler(move |args| {
437 let h = handler.clone();
438 h(args)
439 })
440 .build()?;
441
442 self.add_dyn_tool(Box::new(tool));
443 added.push(name);
444 }
445 Ok(added)
446 }
447
448 async fn execute_single_tool_call(&self, tc: &ToolCallMessage) -> TextMessage {
464 let id_opt = tc.id().map(|s| s.to_string());
465 let func_opt = tc.function();
466
467 if let Some(func) = func_opt {
468 let name = func.name().unwrap_or("").to_string();
469 let args_str = func.arguments().unwrap_or("{}");
470 let args_json: serde_json::Value = serde_json::from_str(args_str)
471 .unwrap_or_else(|_| serde_json::json!({ "_raw": args_str }));
472
473 let content_json = match self.execute_simple(&name, args_json).await {
474 Ok(v) => v,
475 Err(err) => serde_json::json!({
476 "error": { "type": "execution_failed", "message": err.to_string() }
477 }),
478 };
479
480 let s = serde_json::to_string(&content_json).unwrap_or_else(|_| "{}".to_string());
481
482 if let Some(id) = id_opt {
483 TextMessage::tool_with_id(s, id)
484 } else {
485 TextMessage::tool(s)
486 }
487 } else {
488 let s = serde_json::json!({
489 "error": { "type": "missing_function", "message": "tool_call.function is missing" }
490 })
491 .to_string();
492
493 if let Some(id) = id_opt {
494 TextMessage::tool_with_id(s, id)
495 } else {
496 TextMessage::tool(s)
497 }
498 }
499 }
500
501 pub async fn execute_tool_calls_parallel(&self, calls: &[ToolCallMessage]) -> Vec<TextMessage> {
502 let mut set = JoinSet::new();
503
504 let calls_vec = calls.to_vec();
506 for tc in calls_vec {
507 let this = self.clone();
508 set.spawn(async move { this.execute_single_tool_call(&tc).await });
509 }
510
511 let mut messages = Vec::with_capacity(calls.len());
512 while let Some(res) = set.join_next().await {
513 if let Ok(msg) = res {
514 messages.push(msg);
515 }
516 }
517 messages
518 }
519
520 pub async fn execute_tool_calls_ordered(&self, calls: &[ToolCallMessage]) -> Vec<TextMessage> {
539 use futures::future::join_all;
540
541 let calls_vec = calls.to_vec();
542 let futures: Vec<_> = calls_vec
543 .into_iter()
544 .map(|tc| {
545 let this = self.clone();
546 async move { this.execute_single_tool_call(&tc).await }
547 })
548 .collect();
549
550 join_all(futures).await
551 }
552
553 pub fn export_tool_as_function(&self, name: &str) -> Option<Tools> {
556 let tool = self.tools.get(name)?;
557 let meta = tool.metadata();
558 let schema = tool.input_schema();
559 let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
560 Some(Tools::Function { function: func })
561 }
562
563 pub fn export_all_tools_as_functions(&self) -> Vec<Tools> {
565 self.tools
566 .iter()
567 .map(|entry| {
568 let tool = entry.value();
569 let meta = tool.metadata();
570 let schema = tool.input_schema();
571 let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
572 Tools::Function { function: func }
573 })
574 .collect()
575 }
576 pub fn export_tools_filtered<F>(&self, mut filter: F) -> Vec<Tools>
578 where
579 F: FnMut(&crate::toolkits::core::ToolMetadata) -> bool,
580 {
581 self.tools
582 .iter()
583 .filter(|entry| filter(entry.value().metadata()))
584 .map(|entry| {
585 let tool = entry.value();
586 let meta = tool.metadata();
587 let schema = tool.input_schema();
588 let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
589 Tools::Function { function: func }
590 })
591 .collect()
592 }
593
594 async fn execute_once(
595 &self,
596 tool_name: &str,
597 input: &serde_json::Value,
598 ) -> ToolResult<serde_json::Value> {
599 let tool = self
600 .get_tool(tool_name)
601 .ok_or_else(|| error_context().with_tool(tool_name).tool_not_found())?;
602 let execution_future = tool.execute_json(input.clone());
603
604 match self.config.timeout {
605 Some(timeout_duration) => match timeout(timeout_duration, execution_future).await {
606 Ok(result) => result,
607 Err(_) => Err(error_context()
608 .with_tool(tool_name)
609 .timeout_error(timeout_duration)),
610 },
611 None => execution_future.await,
612 }
613 }
614
615 pub fn config(&self) -> &ExecutionConfig {
617 &self.config
618 }
619}
620
621pub struct ExecutorBuilder {
623 config: ExecutionConfig,
624 cache_config: Option<CacheConfig>,
625}
626
627#[derive(Clone)]
628struct CacheConfig {
629 enabled: bool,
630 ttl: Duration,
631 max_size: usize,
632}
633
634impl Default for ExecutorBuilder {
635 fn default() -> Self {
636 Self::new()
637 }
638}
639
640impl ExecutorBuilder {
641 pub fn new() -> Self {
643 Self {
644 config: ExecutionConfig::default(),
645 cache_config: None,
646 }
647 }
648
649 pub fn timeout(mut self, timeout: Duration) -> Self {
651 self.config.timeout = Some(timeout);
652 self
653 }
654
655 pub fn retries(mut self, retries: u32) -> Self {
657 self.config.retry_config.max_retries = retries;
658 self
659 }
660
661 pub fn logging(mut self, enabled: bool) -> Self {
663 self.config.enable_logging = enabled;
664 self
665 }
666
667 pub fn enable_cache(mut self) -> Self {
669 self.cache_config
670 .get_or_insert(CacheConfig {
671 enabled: true,
672 ttl: Duration::from_secs(300),
673 max_size: 1000,
674 })
675 .enabled = true;
676 self
677 }
678
679 pub fn disable_cache(mut self) -> Self {
681 self.cache_config
682 .get_or_insert(CacheConfig {
683 enabled: false,
684 ttl: Duration::from_secs(300),
685 max_size: 1000,
686 })
687 .enabled = false;
688 self
689 }
690
691 pub fn cache_ttl(mut self, ttl: Duration) -> Self {
693 let cfg = self.cache_config.get_or_insert(CacheConfig {
694 enabled: true,
695 ttl: Duration::from_secs(300),
696 max_size: 1000,
697 });
698 cfg.ttl = ttl;
699 self
700 }
701
702 pub fn cache_max_size(mut self, size: usize) -> Self {
704 let cfg = self.cache_config.get_or_insert(CacheConfig {
705 enabled: true,
706 ttl: Duration::from_secs(300),
707 max_size: 1000,
708 });
709 cfg.max_size = size;
710 self
711 }
712
713 pub fn build(self) -> ToolExecutor {
715 let cache = match self.cache_config {
716 Some(cfg) => ToolCallCache::new()
717 .with_enabled(cfg.enabled)
718 .with_ttl(cfg.ttl)
719 .with_max_size(cfg.max_size),
720 None => ToolCallCache::new(),
721 };
722
723 ToolExecutor {
724 tools: Arc::new(DashMap::new()),
725 config: self.config,
726 cache,
727 }
728 }
729}
730
#[cfg(test)]
mod tests {
    //! Unit tests for the executor.
    //!
    //! Tests that exercise failure paths disable retries via the builder: with
    //! the default policy each of them would otherwise sleep through the full
    //! 100 + 200 + 400 ms backoff schedule. Retry counting is covered by the
    //! dedicated retry tests with small retry budgets.

    use super::*;
    use crate::toolkits::core::FunctionTool;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_calculate_delay() {
        let config = RetryConfig::default();

        assert_eq!(config.calculate_delay(0), Duration::ZERO);
        assert_eq!(config.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(200));
        assert_eq!(config.calculate_delay(3), Duration::from_millis(400));

        let config = RetryConfig {
            max_retries: 10,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 3.0,
        };
        assert_eq!(config.calculate_delay(1), Duration::from_millis(500));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(3), Duration::from_secs(1));
    }

    #[test]
    fn test_execution_config_default() {
        let config = ExecutionConfig::default();
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
        assert!(config.validate_parameters);
        assert!(!config.enable_logging);
        assert_eq!(config.retry_config.max_retries, 3);
    }

    #[test]
    fn test_execution_result_success() {
        let result = ExecutionResult::success(
            "test_tool".to_string(),
            serde_json::json!({"value": 42}),
            Duration::from_millis(100),
            2,
        );

        assert_eq!(result.tool_name, "test_tool");
        assert_eq!(result.result, serde_json::json!({"value": 42}));
        assert_eq!(result.duration, Duration::from_millis(100));
        assert!(result.success);
        assert!(result.error.is_none());
        assert_eq!(result.retries, 2);
        assert!(result.metadata.is_empty());
    }

    #[test]
    fn test_execution_result_failure() {
        let result = ExecutionResult::failure(
            "test_tool".to_string(),
            "Something went wrong".to_string(),
            Duration::from_millis(50),
            1,
        );

        assert_eq!(result.tool_name, "test_tool");
        assert_eq!(result.result, serde_json::Value::Null);
        assert_eq!(result.duration, Duration::from_millis(50));
        assert!(!result.success);
        assert_eq!(result.error, Some("Something went wrong".to_string()));
        assert_eq!(result.retries, 1);
        assert!(result.metadata.is_empty());
    }

    #[test]
    fn test_execution_result_with_metadata() {
        let mut result = ExecutionResult::success(
            "test_tool".to_string(),
            serde_json::json!({"value": 42}),
            Duration::from_millis(100),
            0,
        );

        result = result.with_metadata("key1", serde_json::json!("value1"));
        result = result.with_metadata("key2", serde_json::json!({"nested": true}));

        assert_eq!(result.metadata.len(), 2);
        assert_eq!(
            result.metadata.get("key1"),
            Some(&serde_json::json!("value1"))
        );
        assert_eq!(
            result.metadata.get("key2"),
            Some(&serde_json::json!({"nested": true}))
        );
    }

    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult::success(
            "test_tool".to_string(),
            serde_json::json!({"value": 42}),
            Duration::from_millis(100),
            0,
        );

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"tool_name\":\"test_tool\""));
        assert!(json.contains("\"success\":true"));
        assert!(json.contains("\"value\":42"));
    }

    #[test]
    fn test_tool_executor_default() {
        let executor = ToolExecutor::new();
        assert_eq!(executor.tool_names().len(), 0);
        assert_eq!(executor.config.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_tool_executor_register_and_unregister() {
        let executor = ToolExecutor::new();

        let tool = FunctionTool::builder("test_tool", "A test tool")
            .handler(|_args| async move { Ok(serde_json::json!({"result": "success"})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));
        assert_eq!(executor.tool_names().len(), 1);
        assert!(executor.has_tool("test_tool"));

        assert!(executor.unregister("test_tool").is_ok());
        assert_eq!(executor.tool_names().len(), 0);
        assert!(!executor.has_tool("test_tool"));
    }

    #[test]
    fn test_tool_executor_duplicate_tool_panics() {
        let executor = ToolExecutor::new();

        let tool1 = FunctionTool::builder("duplicate_tool", "First tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool2 = FunctionTool::builder("duplicate_tool", "Second tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool1));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            executor.add_dyn_tool(Box::new(tool2));
        }));
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_executor_try_add_dyn_tool() {
        let executor = ToolExecutor::new();

        let tool1 = FunctionTool::builder("test_tool", "First tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool2 = FunctionTool::builder("test_tool", "Second tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.try_add_dyn_tool(Box::new(tool1));
        executor.try_add_dyn_tool(Box::new(tool2));

        assert_eq!(executor.tool_names().len(), 1);
        assert!(executor.has_tool("test_tool"));
    }

    #[test]
    fn test_tool_executor_unregister_nonexistent_tool() {
        let executor = ToolExecutor::new();
        let result = executor.unregister("nonexistent_tool");
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_executor_input_schema() {
        let executor = ToolExecutor::new();

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let tool = FunctionTool::builder("test_tool", "A test tool")
            .schema(schema.clone())
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let retrieved_schema = executor.input_schema("test_tool");
        assert!(retrieved_schema.is_some());
        let retrieved = retrieved_schema.unwrap();

        assert_eq!(retrieved["type"], "object");
        assert_eq!(retrieved["properties"]["name"]["type"], "string");
        assert_eq!(retrieved["additionalProperties"], false);
    }

    #[test]
    fn test_tool_executor_input_schema_nonexistent() {
        let executor = ToolExecutor::new();
        let schema = executor.input_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_tool_executor_tool_names() {
        let executor = ToolExecutor::new();

        let tool1 = FunctionTool::builder("tool1", "First tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool2 = FunctionTool::builder("tool2", "Second tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool3 = FunctionTool::builder("tool3", "Third tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool1));
        executor.add_dyn_tool(Box::new(tool2));
        executor.add_dyn_tool(Box::new(tool3));

        let names = executor.tool_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"tool1".to_string()));
        assert!(names.contains(&"tool2".to_string()));
        assert!(names.contains(&"tool3".to_string()));
    }

    #[tokio::test]
    async fn test_tool_executor_execute_success() {
        let executor = ToolExecutor::new();

        let tool = FunctionTool::builder("add_tool", "Add two numbers")
            .property("a", serde_json::json!({"type": "number"}))
            .property("b", serde_json::json!({"type": "number"}))
            .handler(|args| async move {
                let a = args.get("a").and_then(|v| v.as_i64()).unwrap_or(0);
                let b = args.get("b").and_then(|v| v.as_i64()).unwrap_or(0);
                Ok(serde_json::json!({"result": a + b}))
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({"a": 5, "b": 3});
        let result = executor.execute("add_tool", input).await.unwrap();

        assert!(result.success);
        assert_eq!(result.tool_name, "add_tool");
        assert_eq!(result.result, serde_json::json!({"result": 8}));
        assert_eq!(result.retries, 0);
    }

    #[tokio::test]
    async fn test_tool_executor_execute_failure() {
        let executor = ToolExecutor::builder().retries(0).build();

        let tool = FunctionTool::builder("failing_tool", "Always fails")
            .handler(|_args| async move {
                Err(error_context()
                    .with_tool("failing_tool")
                    .execution_failed("Intentional failure"))
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({});
        let result = executor.execute("failing_tool", input).await.unwrap();

        assert!(!result.success);
        assert_eq!(result.tool_name, "failing_tool");
        assert!(result.error.is_some());
        assert_eq!(result.retries, 0);
    }

    #[tokio::test]
    async fn test_tool_executor_execute_failure_exhausts_retries() {
        let executor = ToolExecutor::builder().retries(1).build();

        let attempt_counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = attempt_counter.clone();

        let tool = FunctionTool::builder("failing_tool", "Always fails")
            .handler(move |_args| {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(error_context()
                        .with_tool("failing_tool")
                        .execution_failed("Intentional failure"))
                }
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let result = executor
            .execute("failing_tool", serde_json::json!({}))
            .await
            .unwrap();

        assert!(!result.success);
        assert_eq!(result.retries, 1);
        // One initial attempt plus one retry.
        assert_eq!(
            attempt_counter.load(std::sync::atomic::Ordering::SeqCst),
            2
        );
    }

    #[tokio::test]
    async fn test_tool_executor_execute_nonexistent_tool() {
        let executor = ToolExecutor::builder().retries(0).build();
        let input = serde_json::json!({});
        let result = executor.execute("nonexistent_tool", input).await.unwrap();

        assert!(!result.success);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn test_tool_executor_execute_simple_success() {
        let executor = ToolExecutor::new();

        let tool = FunctionTool::builder("echo_tool", "Echo input")
            .property("message", serde_json::json!({"type": "string"}))
            .handler(|args| async move { Ok(args) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({"message": "hello"});
        let result = executor.execute_simple("echo_tool", input).await.unwrap();

        assert_eq!(result, serde_json::json!({"message": "hello"}));
    }

    #[tokio::test]
    async fn test_tool_executor_execute_simple_failure() {
        let executor = ToolExecutor::builder().retries(0).build();

        let tool = FunctionTool::builder("failing_tool", "Always fails")
            .handler(|_args| async move {
                Err(error_context()
                    .with_tool("failing_tool")
                    .execution_failed("Intentional failure"))
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({});
        let result = executor.execute_simple("failing_tool", input).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_executor_timeout() {
        let executor = ToolExecutor::builder()
            .timeout(Duration::from_millis(100))
            .retries(0)
            .build();

        let tool = FunctionTool::builder("slow_tool", "Slow tool")
            .handler(|_args| async move {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(serde_json::json!({"done": true}))
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({});
        let result = executor.execute("slow_tool", input).await.unwrap();

        assert!(!result.success);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("Timeout"));
    }

    #[tokio::test]
    async fn test_tool_executor_retry() {
        let executor = ToolExecutor::builder()
            .retries(2)
            .timeout(Duration::from_secs(30))
            .build();

        let attempt_counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = attempt_counter.clone();

        let tool = FunctionTool::builder("flaky_tool", "Flaky tool")
            .handler(move |_args| {
                let counter = counter_clone.clone();
                async move {
                    let attempts = counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    if attempts < 2 {
                        Err(error_context()
                            .with_tool("flaky_tool")
                            .execution_failed("Temporary failure"))
                    } else {
                        Ok(serde_json::json!({"attempts": attempts + 1}))
                    }
                }
            })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let input = serde_json::json!({});
        let result = executor.execute("flaky_tool", input).await.unwrap();

        assert!(result.success);
        assert_eq!(result.retries, 2);
    }

    #[test]
    fn test_executor_builder_default() {
        let builder = ExecutorBuilder::new();
        assert_eq!(builder.config.timeout, Some(Duration::from_secs(30)));
        assert_eq!(builder.config.retry_config.max_retries, 3);
    }

    #[test]
    fn test_executor_builder_timeout() {
        let builder = ExecutorBuilder::new().timeout(Duration::from_secs(60));
        assert_eq!(builder.config.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_executor_builder_retries() {
        let builder = ExecutorBuilder::new().retries(5);
        assert_eq!(builder.config.retry_config.max_retries, 5);
    }

    #[test]
    fn test_executor_builder_logging() {
        let builder = ExecutorBuilder::new().logging(true);
        assert!(builder.config.enable_logging);
    }

    #[test]
    fn test_executor_builder_build() {
        let executor = ExecutorBuilder::new()
            .timeout(Duration::from_secs(60))
            .retries(5)
            .logging(true)
            .build();

        assert_eq!(executor.config.timeout, Some(Duration::from_secs(60)));
        assert_eq!(executor.config.retry_config.max_retries, 5);
        assert!(executor.config.enable_logging);
    }

    #[test]
    fn test_executor_builder_chainable() {
        let builder = ExecutorBuilder::new()
            .timeout(Duration::from_secs(45))
            .retries(3)
            .logging(false)
            .timeout(Duration::from_secs(50))
            .retries(4)
            .logging(true);

        assert_eq!(builder.config.timeout, Some(Duration::from_secs(50)));
        assert_eq!(builder.config.retry_config.max_retries, 4);
        assert!(builder.config.enable_logging);
    }

    #[test]
    fn test_export_tool_as_function() {
        let executor = ToolExecutor::new();

        let tool = FunctionTool::builder("greet_tool", "Greet someone")
            .handler(|_args| async move { Ok(serde_json::json!({"greeting": "hello"})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool));

        let exported = executor.export_tool_as_function("greet_tool");
        assert!(exported.is_some());

        if let Some(Tools::Function { function }) = exported {
            assert_eq!(function.name, "greet_tool");
            assert_eq!(function.description, "Greet someone");
            assert!(function.parameters.is_some());
        } else {
            panic!("Expected Tools::Function");
        }
    }

    #[test]
    fn test_export_tool_as_function_nonexistent() {
        let executor = ToolExecutor::new();
        let exported = executor.export_tool_as_function("nonexistent");
        assert!(exported.is_none());
    }

    #[test]
    fn test_export_all_tools_as_functions() {
        let executor = ToolExecutor::new();

        let tool1 = FunctionTool::builder("tool1", "First tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool2 = FunctionTool::builder("tool2", "Second tool")
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool1));
        executor.add_dyn_tool(Box::new(tool2));

        let exported = executor.export_all_tools_as_functions();
        assert_eq!(exported.len(), 2);

        let names: Vec<_> = exported
            .iter()
            .filter_map(|t| match t {
                Tools::Function { function } => Some(function.name.clone()),
                _ => None,
            })
            .collect();

        assert!(names.contains(&"tool1".to_string()));
        assert!(names.contains(&"tool2".to_string()));
    }

    #[test]
    fn test_export_tools_filtered() {
        let executor = ToolExecutor::new();

        let tool1 = FunctionTool::builder("math_tool", "Math operations")
            .metadata(|m| m.version("1.0.0"))
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        let tool2 = FunctionTool::builder("text_tool", "Text operations")
            .metadata(|m| m.version("2.0.0"))
            .handler(|_args| async move { Ok(serde_json::json!({})) })
            .build()
            .unwrap();

        executor.add_dyn_tool(Box::new(tool1));
        executor.add_dyn_tool(Box::new(tool2));

        let exported = executor.export_tools_filtered(|meta| meta.version == "1.0.0");
        assert_eq!(exported.len(), 1);

        if let Some(Tools::Function { function }) = exported.first() {
            assert_eq!(function.name, "math_tool");
        } else {
            panic!("Expected Tools::Function");
        }
    }

    #[test]
    fn test_execution_result_metadata_serialization() {
        let result = ExecutionResult::success(
            "test_tool".to_string(),
            serde_json::json!({"value": 42}),
            Duration::from_millis(100),
            0,
        )
        .with_metadata("key1", serde_json::json!("value1"))
        .with_metadata("key2", serde_json::json!(123));

        let json = serde_json::to_string(&result).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed["metadata"]["key1"], "value1");
        assert_eq!(parsed["metadata"]["key2"], 123);
    }

    #[test]
    fn test_execution_result_timestamp() {
        let before = std::time::SystemTime::now();
        let result = ExecutionResult::success(
            "test_tool".to_string(),
            serde_json::json!({"value": 42}),
            Duration::from_millis(100),
            0,
        );
        let after = std::time::SystemTime::now();

        assert!(result.timestamp >= before && result.timestamp <= after);
    }
}