Skip to main content

zai_rs/toolkits/
executor.rs

1//! Enhanced tool executor with type-safe builder pattern
2
3use std::{
4    sync::Arc,
5    time::{Duration, Instant},
6};
7
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use tokio::{task::JoinSet, time::timeout};
11
12use super::cache::{CacheKey, ToolCallCache};
13use crate::{
14    model::{
15        chat_base_response::ToolCallMessage,
16        chat_message_types::TextMessage,
17        tools::{Function, Tools},
18    },
19    toolkits::{
20        core::DynTool,
21        error::{ToolResult, error_context},
22    },
23};
24
25/// Type alias for the complex handler type to reduce complexity warnings
26type ToolHandler = std::sync::Arc<
27    dyn Fn(
28            serde_json::Value,
29        ) -> std::pin::Pin<
30            Box<
31                dyn std::future::Future<
32                        Output = crate::toolkits::error::ToolResult<serde_json::Value>,
33                    > + Send,
34            >,
35        > + Send
36        + Sync,
37>;
38
/// Retry policy with exponential backoff.
///
/// The delay before retry `n` (1-based) is
/// `initial_delay * backoff_multiplier^(n - 1)`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the first attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for any single delay.
    pub max_delay: Duration,
    /// Factor applied to the delay for each further retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// 3 retries, starting at 100 ms, doubling each time, capped at 30 s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Compute the delay to wait before retry number `attempt`.
    ///
    /// - `attempt == 0` (the initial try) yields [`Duration::ZERO`], as does a
    ///   zero `initial_delay`.
    /// - The computation is done in nanoseconds, so sub-millisecond delays are
    ///   honoured rather than truncated to zero.
    /// - The result never exceeds `max_delay`, even for very large `attempt`
    ///   values.
    /// - Negative products (e.g. from a negative multiplier) clamp to zero.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 || self.initial_delay.is_zero() {
            return Duration::ZERO;
        }

        // Saturate rather than wrap: an `as i32` cast turns huge attempt
        // counts into negative exponents and therefore tiny delays.
        let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
        let delay_ns =
            self.initial_delay.as_nanos() as f64 * self.backoff_multiplier.powi(exponent);

        // `f64::min` yields the cap when `delay_ns` is NaN or infinite. The
        // float-to-int `as` cast saturates, so negative values become 0.
        let capped_ns = delay_ns.min(self.max_delay.as_nanos() as f64);
        Duration::from_nanos(capped_ns as u64)
    }
}
72
73/// Execution configuration with type-safe builder
74#[derive(Debug, Clone)]
75pub struct ExecutionConfig {
76    pub timeout: Option<Duration>,
77    pub retry_config: RetryConfig,
78    pub validate_parameters: bool,
79    pub enable_logging: bool,
80}
81
82impl Default for ExecutionConfig {
83    fn default() -> Self {
84        Self {
85            timeout: Some(Duration::from_secs(30)),
86            retry_config: RetryConfig::default(),
87            validate_parameters: true,
88            enable_logging: false,
89        }
90    }
91}
92
93/// Execution result with enhanced metadata
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ExecutionResult {
96    pub tool_name: String,
97    pub result: serde_json::Value,
98    pub duration: Duration,
99    pub success: bool,
100    pub error: Option<String>,
101    pub retries: u32,
102    pub timestamp: std::time::SystemTime,
103    pub metadata: std::collections::HashMap<String, serde_json::Value>,
104}
105
106impl ExecutionResult {
107    pub fn success(
108        tool_name: String,
109        result: serde_json::Value,
110        duration: Duration,
111        retries: u32,
112    ) -> Self {
113        Self {
114            tool_name,
115            result,
116            duration,
117            success: true,
118            error: None,
119            retries,
120            timestamp: std::time::SystemTime::now(),
121            metadata: std::collections::HashMap::new(),
122        }
123    }
124
125    pub fn failure(tool_name: String, error: String, duration: Duration, retries: u32) -> Self {
126        Self {
127            tool_name,
128            result: serde_json::Value::Null,
129            duration,
130            success: false,
131            error: Some(error),
132            retries,
133            timestamp: std::time::SystemTime::now(),
134            metadata: std::collections::HashMap::new(),
135        }
136    }
137
138    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
139        self.metadata.insert(key.into(), value);
140        self
141    }
142}
143
144/// Enhanced tool executor with built-in registry and fluent API
145#[derive(Clone)]
146pub struct ToolExecutor {
147    tools: Arc<DashMap<String, Box<dyn DynTool>>>,
148    config: ExecutionConfig,
149    cache: ToolCallCache,
150}
151
152impl std::fmt::Debug for ToolExecutor {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        let tool_count = self.tools.len();
155        let cache_enabled = self.cache.stats().total_entries > 0;
156        f.debug_struct("ToolExecutor")
157            .field("tool_count", &tool_count)
158            .field("config", &self.config)
159            .field("cache_enabled", &cache_enabled)
160            .finish()
161    }
162}
163
164impl Default for ToolExecutor {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl ToolExecutor {
171    /// Create a new executor with default config
172    pub fn new() -> Self {
173        Self {
174            tools: Arc::new(DashMap::new()),
175            config: ExecutionConfig::default(),
176            cache: ToolCallCache::new(),
177        }
178    }
179
180    /// Create an executor builder for fluent API
181    pub fn builder() -> ExecutorBuilder {
182        ExecutorBuilder::new()
183    }
184
185    /// Enable or disable tool call result caching
186    pub fn with_cache_enabled(mut self, enabled: bool) -> Self {
187        self.cache = self.cache.with_enabled(enabled);
188        self
189    }
190
191    /// Set cache TTL (time-to-live)
192    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
193        self.cache = self.cache.with_ttl(ttl);
194        self
195    }
196
197    /// Set maximum cache size
198    pub fn with_cache_max_size(mut self, size: usize) -> Self {
199        self.cache = self.cache.with_max_size(size);
200        self
201    }
202
203    /// Clear the cache
204    pub fn clear_cache(&self) {
205        self.cache.clear();
206    }
207
208    /// Invalidate cache for a specific tool
209    pub fn invalidate_cache_for_tool(&self, tool_name: &str) {
210        self.cache.invalidate_tool(tool_name);
211    }
212
213    /// Get cache statistics
214    pub fn cache_stats(&self) -> super::cache::CacheStats {
215        self.cache.stats()
216    }
217
218    /// Chain-friendly: add a dynamic tool (panics on error)
219    pub fn add_dyn_tool(&self, tool: Box<dyn DynTool>) -> &Self {
220        let name = tool.name().to_string();
221        if self.tools.contains_key(&name) {
222            panic!("Tool '{}' is already registered", name);
223        }
224        self.tools.insert(name, tool);
225        self
226    }
227
228    /// Chain-friendly: try to add a dynamic tool (ignores error)
229    pub fn try_add_dyn_tool(&self, tool: Box<dyn DynTool>) -> &Self {
230        let name = tool.name().to_string();
231        self.tools.entry(name).or_insert(tool);
232        self
233    }
234
235    /// Unregister a tool
236    pub fn unregister(&self, name: &str) -> ToolResult<()> {
237        if self.tools.remove(name).is_none() {
238            return Err(error_context().tool_not_found());
239        }
240        Ok(())
241    }
242
243    /// Get input schema for a tool
244    pub fn input_schema(&self, name: &str) -> Option<serde_json::Value> {
245        self.tools.get(name).map(|t| t.input_schema())
246    }
247
248    /// Check if tool exists
249    pub fn has_tool(&self, name: &str) -> bool {
250        self.tools.contains_key(name)
251    }
252
253    /// List tool names
254    pub fn tool_names(&self) -> Vec<String> {
255        self.tools.iter().map(|entry| entry.key().clone()).collect()
256    }
257
258    fn get_tool(&self, name: &str) -> Option<Box<dyn DynTool>> {
259        self.tools.get(name).map(|t| t.clone_box())
260    }
261
262    /// Execute a tool with detailed result and exponential backoff
263    pub async fn execute(
264        &self,
265        tool_name: &str,
266        input: serde_json::Value,
267    ) -> ToolResult<ExecutionResult> {
268        let start_time = Instant::now();
269        let mut retries = 0;
270        let retry_config = &self.config.retry_config;
271
272        // Check cache first
273        let cache_key = CacheKey::new(tool_name.to_string(), input.clone());
274        if let Some(cached_result) = self.cache.get(&cache_key) {
275            let duration = start_time.elapsed();
276            return Ok(ExecutionResult::success(
277                tool_name.to_string(),
278                cached_result,
279                duration,
280                retries,
281            )
282            .with_metadata("cache_hit", serde_json::Value::Bool(true)));
283        }
284
285        loop {
286            match self.execute_once(tool_name, &input).await {
287                Ok(result) => {
288                    let duration = start_time.elapsed();
289                    // Cache the successful result
290                    self.cache.insert(cache_key, result.clone(), None);
291
292                    return Ok(ExecutionResult::success(
293                        tool_name.to_string(),
294                        result,
295                        duration,
296                        retries,
297                    )
298                    .with_metadata("cache_hit", serde_json::Value::Bool(false)));
299                },
300                Err(error) => {
301                    if retries >= retry_config.max_retries {
302                        let duration = start_time.elapsed();
303                        return Ok(ExecutionResult::failure(
304                            tool_name.to_string(),
305                            error.to_string(),
306                            duration,
307                            retries,
308                        ));
309                    }
310
311                    retries += 1;
312
313                    if self.config.enable_logging {
314                        eprintln!("Tool execution failed (attempt {}): {}", retries, error);
315                    }
316
317                    // Use exponential backoff
318                    let delay = retry_config.calculate_delay(retries);
319                    tokio::time::sleep(delay).await;
320                },
321            }
322        }
323    }
324
325    /// Execute a tool and return only the result
326    pub async fn execute_simple(
327        &self,
328        tool_name: &str,
329        input: serde_json::Value,
330    ) -> ToolResult<serde_json::Value> {
331        let result = self.execute(tool_name, input).await?;
332        if result.success {
333            Ok(result.result)
334        } else {
335            Err(error_context()
336                .with_tool(tool_name)
337                .execution_failed(result.error.unwrap_or_else(|| "Unknown error".to_string())))
338        }
339    }
340
341    /// Bulk load function specs from a directory of .json files and register
342    /// them with handlers.
343    ///
344    /// - Each file should contain either of the following shapes:
345    ///   1) {"name":..., "description":..., "parameters": {...}}
346    ///   2) {"type":"function", "function": {"name":..., "description":...,
347    ///      "parameters": {...}}}
348    /// - `handlers` maps function `name` -> handler closure
349    /// - `strict`: when true, missing handler for any spec will return error;
350    ///   when false, specs without handlers are skipped
351    ///
352    /// Returns the list of function names successfully registered.
353    pub fn add_functions_from_dir_with_registry(
354        &self,
355        dir: impl AsRef<std::path::Path>,
356        handlers: &std::collections::HashMap<String, ToolHandler>,
357        strict: bool,
358    ) -> ToolResult<Vec<String>> {
359        use std::fs;
360
361        use serde_json::Value;
362        let dir = dir.as_ref();
363        let mut added = Vec::new();
364        let read_dir = fs::read_dir(dir).map_err(|e| {
365            error_context().invalid_parameters(format!(
366                "Failed to read dir {}: {}",
367                dir.display(),
368                e
369            ))
370        })?;
371        for entry in read_dir {
372            let entry = match entry {
373                Ok(e) => e,
374                Err(e) => {
375                    return Err(
376                        error_context().invalid_parameters(format!("Dir entry error: {}", e))
377                    );
378                },
379            };
380            let path = entry.path();
381            if !path.is_file() {
382                continue;
383            }
384            if path.extension().and_then(|s| s.to_str()) != Some("json") {
385                continue;
386            }
387            let content = fs::read_to_string(&path).map_err(|e| {
388                error_context().invalid_parameters(format!(
389                    "Failed to read {}: {}",
390                    path.display(),
391                    e
392                ))
393            })?;
394            let spec: Value = serde_json::from_str(&content).map_err(|e| {
395                error_context().invalid_parameters(format!(
396                    "Invalid JSON in {}: {}",
397                    path.display(),
398                    e
399                ))
400            })?;
401
402            // Extract name/description/parameters from spec
403            let (name, description, parameters) =
404                crate::toolkits::core::parse_function_spec_details(&spec).map_err(|e| {
405                    error_context().invalid_parameters(format!(
406                        "Failed to parse spec {}: {}",
407                        path.display(),
408                        e
409                    ))
410                })?;
411
412            let handler = match handlers.get(&name) {
413                Some(h) => h.clone(),
414                None => {
415                    if strict {
416                        return Err(error_context().invalid_parameters(format!(
417                            "No handler registered for function '{}' (file {})",
418                            name,
419                            path.display()
420                        )));
421                    } else {
422                        // skip silently
423                        continue;
424                    }
425                },
426            };
427
428            // Build FunctionTool via existing builder path (will auto-complete schema
429            // defaults)
430            let mut builder =
431                crate::toolkits::core::FunctionTool::builder(name.clone(), description);
432            if let Some(p) = parameters {
433                builder = builder.schema(p);
434            }
435            let tool = builder
436                .handler(move |args| {
437                    let h = handler.clone();
438                    h(args)
439                })
440                .build()?;
441
442            self.add_dyn_tool(Box::new(tool));
443            added.push(name);
444        }
445        Ok(added)
446    }
447
448    /// Execute LLM tool_calls in parallel and return `TextMessage::tool`
449    /// messages.
450    ///
451    /// Behavior:
452    /// - Parses each ToolCallMessage's function.arguments (stringified JSON
453    ///   supported)
454    /// - Runs all tools concurrently using this executor
455    /// - Captures errors per-call and encodes them as JSON: { "error": {
456    ///   "type": "...", "message": "..." } }
457    /// - Preserves tool_call `id` by emitting TextMessage::tool_with_id when
458    ///   present
459    ///
460    /// Returns:
461    /// - `Vec<TextMessage>` ready to be appended to ChatCompletion as tool
462    ///   messages.
463    async fn execute_single_tool_call(&self, tc: &ToolCallMessage) -> TextMessage {
464        let id_opt = tc.id().map(|s| s.to_string());
465        let func_opt = tc.function();
466
467        if let Some(func) = func_opt {
468            let name = func.name().unwrap_or("").to_string();
469            let args_str = func.arguments().unwrap_or("{}");
470            let args_json: serde_json::Value = serde_json::from_str(args_str)
471                .unwrap_or_else(|_| serde_json::json!({ "_raw": args_str }));
472
473            let content_json = match self.execute_simple(&name, args_json).await {
474                Ok(v) => v,
475                Err(err) => serde_json::json!({
476                    "error": { "type": "execution_failed", "message": err.to_string() }
477                }),
478            };
479
480            let s = serde_json::to_string(&content_json).unwrap_or_else(|_| "{}".to_string());
481
482            if let Some(id) = id_opt {
483                TextMessage::tool_with_id(s, id)
484            } else {
485                TextMessage::tool(s)
486            }
487        } else {
488            let s = serde_json::json!({
489                "error": { "type": "missing_function", "message": "tool_call.function is missing" }
490            })
491            .to_string();
492
493            if let Some(id) = id_opt {
494                TextMessage::tool_with_id(s, id)
495            } else {
496                TextMessage::tool(s)
497            }
498        }
499    }
500
501    pub async fn execute_tool_calls_parallel(&self, calls: &[ToolCallMessage]) -> Vec<TextMessage> {
502        let mut set = JoinSet::new();
503
504        // Clone the calls to avoid borrowing issues
505        let calls_vec = calls.to_vec();
506        for tc in calls_vec {
507            let this = self.clone();
508            set.spawn(async move { this.execute_single_tool_call(&tc).await });
509        }
510
511        let mut messages = Vec::with_capacity(calls.len());
512        while let Some(res) = set.join_next().await {
513            if let Ok(msg) = res {
514                messages.push(msg);
515            }
516        }
517        messages
518    }
519
520    /// Execute LLM tool_calls in parallel with result ordering preserved
521    ///
522    /// This method guarantees that results are returned in the same order as
523    /// the input calls, which is important for maintaining conversation
524    /// context in LLM interactions.
525    ///
526    /// Behavior:
527    /// - Parses each ToolCallMessage's function.arguments (stringified JSON
528    ///   supported)
529    /// - Runs all tools concurrently using this executor
530    /// - Preserves the original order of tool calls in results
531    /// - Captures errors per-call and encodes them as JSON
532    /// - Preserves tool_call `id` by emitting TextMessage::tool_with_id when
533    ///   present
534    ///
535    /// Returns:
536    /// - Vec<TextMessage> in the same order as input calls, ready for
537    ///   ChatCompletion
538    pub async fn execute_tool_calls_ordered(&self, calls: &[ToolCallMessage]) -> Vec<TextMessage> {
539        use futures::future::join_all;
540
541        let calls_vec = calls.to_vec();
542        let futures: Vec<_> = calls_vec
543            .into_iter()
544            .map(|tc| {
545                let this = self.clone();
546                async move { this.execute_single_tool_call(&tc).await }
547            })
548            .collect();
549
550        join_all(futures).await
551    }
552
553    /// Export a single registered tool as Tools::Function (for LLM function
554    /// calling)
555    pub fn export_tool_as_function(&self, name: &str) -> Option<Tools> {
556        let tool = self.tools.get(name)?;
557        let meta = tool.metadata();
558        let schema = tool.input_schema();
559        let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
560        Some(Tools::Function { function: func })
561    }
562
563    /// Export all registered tools as a Vec<Tools::Function>
564    pub fn export_all_tools_as_functions(&self) -> Vec<Tools> {
565        self.tools
566            .iter()
567            .map(|entry| {
568                let tool = entry.value();
569                let meta = tool.metadata();
570                let schema = tool.input_schema();
571                let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
572                Tools::Function { function: func }
573            })
574            .collect()
575    }
576    /// Export all registered tools with a metadata filter as Tools::Function
577    pub fn export_tools_filtered<F>(&self, mut filter: F) -> Vec<Tools>
578    where
579        F: FnMut(&crate::toolkits::core::ToolMetadata) -> bool,
580    {
581        self.tools
582            .iter()
583            .filter(|entry| filter(entry.value().metadata()))
584            .map(|entry| {
585                let tool = entry.value();
586                let meta = tool.metadata();
587                let schema = tool.input_schema();
588                let func = Function::new(meta.name.clone(), meta.description.clone(), schema);
589                Tools::Function { function: func }
590            })
591            .collect()
592    }
593
594    async fn execute_once(
595        &self,
596        tool_name: &str,
597        input: &serde_json::Value,
598    ) -> ToolResult<serde_json::Value> {
599        let tool = self
600            .get_tool(tool_name)
601            .ok_or_else(|| error_context().with_tool(tool_name).tool_not_found())?;
602        let execution_future = tool.execute_json(input.clone());
603
604        match self.config.timeout {
605            Some(timeout_duration) => match timeout(timeout_duration, execution_future).await {
606                Ok(result) => result,
607                Err(_) => Err(error_context()
608                    .with_tool(tool_name)
609                    .timeout_error(timeout_duration)),
610            },
611            None => execution_future.await,
612        }
613    }
614
615    /// Get the config
616    pub fn config(&self) -> &ExecutionConfig {
617        &self.config
618    }
619}
620
621/// Builder for creating tool executors with fluent API
622pub struct ExecutorBuilder {
623    config: ExecutionConfig,
624    cache_config: Option<CacheConfig>,
625}
626
627#[derive(Clone)]
628struct CacheConfig {
629    enabled: bool,
630    ttl: Duration,
631    max_size: usize,
632}
633
634impl Default for ExecutorBuilder {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
640impl ExecutorBuilder {
641    /// Create a new executor builder
642    pub fn new() -> Self {
643        Self {
644            config: ExecutionConfig::default(),
645            cache_config: None,
646        }
647    }
648
649    /// Set timeout for tool execution
650    pub fn timeout(mut self, timeout: Duration) -> Self {
651        self.config.timeout = Some(timeout);
652        self
653    }
654
655    /// Set maximum number of retries
656    pub fn retries(mut self, retries: u32) -> Self {
657        self.config.retry_config.max_retries = retries;
658        self
659    }
660
661    /// Enable or disable logging
662    pub fn logging(mut self, enabled: bool) -> Self {
663        self.config.enable_logging = enabled;
664        self
665    }
666
667    /// Enable tool call result caching
668    pub fn enable_cache(mut self) -> Self {
669        self.cache_config
670            .get_or_insert(CacheConfig {
671                enabled: true,
672                ttl: Duration::from_secs(300),
673                max_size: 1000,
674            })
675            .enabled = true;
676        self
677    }
678
679    /// Disable tool call result caching
680    pub fn disable_cache(mut self) -> Self {
681        self.cache_config
682            .get_or_insert(CacheConfig {
683                enabled: false,
684                ttl: Duration::from_secs(300),
685                max_size: 1000,
686            })
687            .enabled = false;
688        self
689    }
690
691    /// Set cache TTL
692    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
693        let cfg = self.cache_config.get_or_insert(CacheConfig {
694            enabled: true,
695            ttl: Duration::from_secs(300),
696            max_size: 1000,
697        });
698        cfg.ttl = ttl;
699        self
700    }
701
702    /// Set maximum cache size
703    pub fn cache_max_size(mut self, size: usize) -> Self {
704        let cfg = self.cache_config.get_or_insert(CacheConfig {
705            enabled: true,
706            ttl: Duration::from_secs(300),
707            max_size: 1000,
708        });
709        cfg.max_size = size;
710        self
711    }
712
713    /// Build the final executor
714    pub fn build(self) -> ToolExecutor {
715        let cache = match self.cache_config {
716            Some(cfg) => ToolCallCache::new()
717                .with_enabled(cfg.enabled)
718                .with_ttl(cfg.ttl)
719                .with_max_size(cfg.max_size),
720            None => ToolCallCache::new(),
721        };
722
723        ToolExecutor {
724            tools: Arc::new(DashMap::new()),
725            config: self.config,
726            cache,
727        }
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use crate::toolkits::core::FunctionTool;
735
736    #[test]
737    fn test_retry_config_default() {
738        let config = RetryConfig::default();
739        assert_eq!(config.max_retries, 3);
740        assert_eq!(config.initial_delay, Duration::from_millis(100));
741        assert_eq!(config.max_delay, Duration::from_secs(30));
742        assert_eq!(config.backoff_multiplier, 2.0);
743    }
744
745    #[test]
746    fn test_retry_config_calculate_delay() {
747        let config = RetryConfig::default();
748
749        // First attempt should have zero delay
750        assert_eq!(config.calculate_delay(0), Duration::ZERO);
751
752        // Second attempt should have initial delay
753        assert_eq!(config.calculate_delay(1), Duration::from_millis(100));
754
755        // Third attempt should double (100 * 2)
756        assert_eq!(config.calculate_delay(2), Duration::from_millis(200));
757
758        // Fourth attempt should quadruple (100 * 2^2)
759        assert_eq!(config.calculate_delay(3), Duration::from_millis(400));
760
761        // Test with exponential growth that exceeds max_delay
762        let config = RetryConfig {
763            max_retries: 10,
764            initial_delay: Duration::from_millis(500),
765            max_delay: Duration::from_secs(1),
766            backoff_multiplier: 3.0,
767        };
768        // 500ms, then 1500ms (capped at 1000ms)
769        assert_eq!(config.calculate_delay(1), Duration::from_millis(500));
770        assert_eq!(config.calculate_delay(2), Duration::from_secs(1));
771        assert_eq!(config.calculate_delay(3), Duration::from_secs(1));
772    }
773
774    #[test]
775    fn test_execution_config_default() {
776        let config = ExecutionConfig::default();
777        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
778        assert!(config.validate_parameters);
779        assert!(!config.enable_logging);
780        assert_eq!(config.retry_config.max_retries, 3);
781    }
782
783    #[test]
784    fn test_execution_result_success() {
785        let result = ExecutionResult::success(
786            "test_tool".to_string(),
787            serde_json::json!({"value": 42}),
788            Duration::from_millis(100),
789            2,
790        );
791
792        assert_eq!(result.tool_name, "test_tool");
793        assert_eq!(result.result, serde_json::json!({"value": 42}));
794        assert_eq!(result.duration, Duration::from_millis(100));
795        assert!(result.success);
796        assert!(result.error.is_none());
797        assert_eq!(result.retries, 2);
798        assert!(result.metadata.is_empty());
799    }
800
801    #[test]
802    fn test_execution_result_failure() {
803        let result = ExecutionResult::failure(
804            "test_tool".to_string(),
805            "Something went wrong".to_string(),
806            Duration::from_millis(50),
807            1,
808        );
809
810        assert_eq!(result.tool_name, "test_tool");
811        assert_eq!(result.result, serde_json::Value::Null);
812        assert_eq!(result.duration, Duration::from_millis(50));
813        assert!(!result.success);
814        assert_eq!(result.error, Some("Something went wrong".to_string()));
815        assert_eq!(result.retries, 1);
816        assert!(result.metadata.is_empty());
817    }
818
819    #[test]
820    fn test_execution_result_with_metadata() {
821        let mut result = ExecutionResult::success(
822            "test_tool".to_string(),
823            serde_json::json!({"value": 42}),
824            Duration::from_millis(100),
825            0,
826        );
827
828        result = result.with_metadata("key1", serde_json::json!("value1"));
829        result = result.with_metadata("key2", serde_json::json!({"nested": true}));
830
831        assert_eq!(result.metadata.len(), 2);
832        assert_eq!(
833            result.metadata.get("key1"),
834            Some(&serde_json::json!("value1"))
835        );
836        assert_eq!(
837            result.metadata.get("key2"),
838            Some(&serde_json::json!({"nested": true}))
839        );
840    }
841
842    #[test]
843    fn test_execution_result_serialization() {
844        let result = ExecutionResult::success(
845            "test_tool".to_string(),
846            serde_json::json!({"value": 42}),
847            Duration::from_millis(100),
848            0,
849        );
850
851        let json = serde_json::to_string(&result).unwrap();
852        assert!(json.contains("\"tool_name\":\"test_tool\""));
853        assert!(json.contains("\"success\":true"));
854        assert!(json.contains("\"value\":42"));
855    }
856
857    #[test]
858    fn test_tool_executor_default() {
859        let executor = ToolExecutor::new();
860        assert_eq!(executor.tool_names().len(), 0);
861        assert_eq!(executor.config.timeout, Some(Duration::from_secs(30)));
862    }
863
864    #[test]
865    fn test_tool_executor_register_and_unregister() {
866        let executor = ToolExecutor::new();
867
868        // Create a simple test tool
869        let tool = FunctionTool::builder("test_tool", "A test tool")
870            .handler(|_args| async move { Ok(serde_json::json!({"result": "success"})) })
871            .build()
872            .unwrap();
873
874        // Register the tool
875        executor.add_dyn_tool(Box::new(tool));
876        assert_eq!(executor.tool_names().len(), 1);
877        assert!(executor.has_tool("test_tool"));
878
879        // Unregister the tool
880        assert!(executor.unregister("test_tool").is_ok());
881        assert_eq!(executor.tool_names().len(), 0);
882        assert!(!executor.has_tool("test_tool"));
883    }
884
885    #[test]
886    fn test_tool_executor_duplicate_tool_panics() {
887        let executor = ToolExecutor::new();
888
889        let tool1 = FunctionTool::builder("duplicate_tool", "First tool")
890            .handler(|_args| async move { Ok(serde_json::json!({})) })
891            .build()
892            .unwrap();
893
894        let tool2 = FunctionTool::builder("duplicate_tool", "Second tool")
895            .handler(|_args| async move { Ok(serde_json::json!({})) })
896            .build()
897            .unwrap();
898
899        executor.add_dyn_tool(Box::new(tool1));
900
901        // Adding duplicate tool should panic
902        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
903            executor.add_dyn_tool(Box::new(tool2));
904        }));
905        assert!(result.is_err());
906    }
907
908    #[test]
909    fn test_tool_executor_try_add_dyn_tool() {
910        let executor = ToolExecutor::new();
911
912        let tool1 = FunctionTool::builder("test_tool", "First tool")
913            .handler(|_args| async move { Ok(serde_json::json!({})) })
914            .build()
915            .unwrap();
916
917        let tool2 = FunctionTool::builder("test_tool", "Second tool")
918            .handler(|_args| async move { Ok(serde_json::json!({})) })
919            .build()
920            .unwrap();
921
922        executor.try_add_dyn_tool(Box::new(tool1));
923        executor.try_add_dyn_tool(Box::new(tool2));
924
925        // Only one tool should be registered (second should be ignored)
926        assert_eq!(executor.tool_names().len(), 1);
927        assert!(executor.has_tool("test_tool"));
928    }
929
930    #[test]
931    fn test_tool_executor_unregister_nonexistent_tool() {
932        let executor = ToolExecutor::new();
933        let result = executor.unregister("nonexistent_tool");
934        assert!(result.is_err());
935    }
936
937    #[test]
938    fn test_tool_executor_input_schema() {
939        let executor = ToolExecutor::new();
940
941        let schema = serde_json::json!({
942            "type": "object",
943            "properties": {
944                "name": {"type": "string"}
945            }
946        });
947
948        let tool = FunctionTool::builder("test_tool", "A test tool")
949            .schema(schema.clone())
950            .handler(|_args| async move { Ok(serde_json::json!({})) })
951            .build()
952            .unwrap();
953
954        executor.add_dyn_tool(Box::new(tool));
955
956        let retrieved_schema = executor.input_schema("test_tool");
957        assert!(retrieved_schema.is_some());
958        let retrieved = retrieved_schema.unwrap();
959
960        // Check that schema contains expected properties
961        assert_eq!(retrieved["type"], "object");
962        assert_eq!(retrieved["properties"]["name"]["type"], "string");
963        // additionalProperties is automatically set by FunctionToolBuilder
964        assert_eq!(retrieved["additionalProperties"], false);
965    }
966
967    #[test]
968    fn test_tool_executor_input_schema_nonexistent() {
969        let executor = ToolExecutor::new();
970        let schema = executor.input_schema("nonexistent");
971        assert!(schema.is_none());
972    }
973
974    #[test]
975    fn test_tool_executor_tool_names() {
976        let executor = ToolExecutor::new();
977
978        let tool1 = FunctionTool::builder("tool1", "First tool")
979            .handler(|_args| async move { Ok(serde_json::json!({})) })
980            .build()
981            .unwrap();
982
983        let tool2 = FunctionTool::builder("tool2", "Second tool")
984            .handler(|_args| async move { Ok(serde_json::json!({})) })
985            .build()
986            .unwrap();
987
988        let tool3 = FunctionTool::builder("tool3", "Third tool")
989            .handler(|_args| async move { Ok(serde_json::json!({})) })
990            .build()
991            .unwrap();
992
993        executor.add_dyn_tool(Box::new(tool1));
994        executor.add_dyn_tool(Box::new(tool2));
995        executor.add_dyn_tool(Box::new(tool3));
996
997        let names = executor.tool_names();
998        assert_eq!(names.len(), 3);
999        assert!(names.contains(&"tool1".to_string()));
1000        assert!(names.contains(&"tool2".to_string()));
1001        assert!(names.contains(&"tool3".to_string()));
1002    }
1003
1004    #[tokio::test]
1005    async fn test_tool_executor_execute_success() {
1006        let executor = ToolExecutor::new();
1007
1008        let tool = FunctionTool::builder("add_tool", "Add two numbers")
1009            .property("a", serde_json::json!({"type": "number"}))
1010            .property("b", serde_json::json!({"type": "number"}))
1011            .handler(|args| async move {
1012                let a = args.get("a").and_then(|v| v.as_i64()).unwrap_or(0);
1013                let b = args.get("b").and_then(|v| v.as_i64()).unwrap_or(0);
1014                Ok(serde_json::json!({"result": a + b}))
1015            })
1016            .build()
1017            .unwrap();
1018
1019        executor.add_dyn_tool(Box::new(tool));
1020
1021        let input = serde_json::json!({"a": 5, "b": 3});
1022        let result = executor.execute("add_tool", input).await.unwrap();
1023
1024        assert!(result.success);
1025        assert_eq!(result.tool_name, "add_tool");
1026        assert_eq!(result.result, serde_json::json!({"result": 8}));
1027        assert_eq!(result.retries, 0);
1028    }
1029
1030    #[tokio::test]
1031    async fn test_tool_executor_execute_failure() {
1032        let executor = ToolExecutor::new();
1033
1034        let tool = FunctionTool::builder("failing_tool", "Always fails")
1035            .handler(|_args| async move {
1036                Err(error_context()
1037                    .with_tool("failing_tool")
1038                    .execution_failed("Intentional failure"))
1039            })
1040            .build()
1041            .unwrap();
1042
1043        executor.add_dyn_tool(Box::new(tool));
1044
1045        let input = serde_json::json!({});
1046        let result = executor.execute("failing_tool", input).await.unwrap();
1047
1048        assert!(!result.success);
1049        assert_eq!(result.tool_name, "failing_tool");
1050        assert!(result.error.is_some());
1051    }
1052
1053    #[tokio::test]
1054    async fn test_tool_executor_execute_nonexistent_tool() {
1055        let executor = ToolExecutor::new();
1056        let input = serde_json::json!({});
1057        let result = executor.execute("nonexistent_tool", input).await.unwrap();
1058
1059        assert!(!result.success);
1060        assert!(result.error.is_some());
1061    }
1062
1063    #[tokio::test]
1064    async fn test_tool_executor_execute_simple_success() {
1065        let executor = ToolExecutor::new();
1066
1067        let tool = FunctionTool::builder("echo_tool", "Echo input")
1068            .property("message", serde_json::json!({"type": "string"}))
1069            .handler(|args| async move { Ok(args) })
1070            .build()
1071            .unwrap();
1072
1073        executor.add_dyn_tool(Box::new(tool));
1074
1075        let input = serde_json::json!({"message": "hello"});
1076        let result = executor.execute_simple("echo_tool", input).await.unwrap();
1077
1078        assert_eq!(result, serde_json::json!({"message": "hello"}));
1079    }
1080
1081    #[tokio::test]
1082    async fn test_tool_executor_execute_simple_failure() {
1083        let executor = ToolExecutor::new();
1084
1085        let tool = FunctionTool::builder("failing_tool", "Always fails")
1086            .handler(|_args| async move {
1087                Err(error_context()
1088                    .with_tool("failing_tool")
1089                    .execution_failed("Intentional failure"))
1090            })
1091            .build()
1092            .unwrap();
1093
1094        executor.add_dyn_tool(Box::new(tool));
1095
1096        let input = serde_json::json!({});
1097        let result = executor.execute_simple("failing_tool", input).await;
1098
1099        assert!(result.is_err());
1100    }
1101
1102    #[tokio::test]
1103    async fn test_tool_executor_timeout() {
1104        let executor = ToolExecutor::builder()
1105            .timeout(Duration::from_millis(100))
1106            .build();
1107
1108        let tool = FunctionTool::builder("slow_tool", "Slow tool")
1109            .handler(|_args| async move {
1110                tokio::time::sleep(Duration::from_secs(1)).await;
1111                Ok(serde_json::json!({"done": true}))
1112            })
1113            .build()
1114            .unwrap();
1115
1116        executor.add_dyn_tool(Box::new(tool));
1117
1118        let input = serde_json::json!({});
1119        let result = executor.execute("slow_tool", input).await.unwrap();
1120
1121        assert!(!result.success);
1122        assert!(result.error.is_some());
1123        assert!(result.error.unwrap().contains("Timeout"));
1124    }
1125
1126    #[tokio::test]
1127    async fn test_tool_executor_retry() {
1128        let executor = ToolExecutor::builder()
1129            .retries(2)
1130            .timeout(Duration::from_secs(30))
1131            .build();
1132
1133        let attempt_counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
1134        let counter_clone = attempt_counter.clone();
1135
1136        let tool = FunctionTool::builder("flaky_tool", "Flaky tool")
1137            .handler(move |_args| {
1138                let counter = counter_clone.clone();
1139                async move {
1140                    let attempts = counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1141                    if attempts < 2 {
1142                        Err(error_context()
1143                            .with_tool("flaky_tool")
1144                            .execution_failed("Temporary failure"))
1145                    } else {
1146                        Ok(serde_json::json!({"attempts": attempts + 1}))
1147                    }
1148                }
1149            })
1150            .build()
1151            .unwrap();
1152
1153        executor.add_dyn_tool(Box::new(tool));
1154
1155        let input = serde_json::json!({});
1156        let result = executor.execute("flaky_tool", input).await.unwrap();
1157
1158        assert!(result.success);
1159        assert_eq!(result.retries, 2);
1160    }
1161
1162    #[test]
1163    fn test_executor_builder_default() {
1164        let builder = ExecutorBuilder::new();
1165        assert_eq!(builder.config.timeout, Some(Duration::from_secs(30)));
1166        assert_eq!(builder.config.retry_config.max_retries, 3);
1167    }
1168
1169    #[test]
1170    fn test_executor_builder_timeout() {
1171        let builder = ExecutorBuilder::new().timeout(Duration::from_secs(60));
1172        assert_eq!(builder.config.timeout, Some(Duration::from_secs(60)));
1173    }
1174
1175    #[test]
1176    fn test_executor_builder_retries() {
1177        let builder = ExecutorBuilder::new().retries(5);
1178        assert_eq!(builder.config.retry_config.max_retries, 5);
1179    }
1180
1181    #[test]
1182    fn test_executor_builder_logging() {
1183        let builder = ExecutorBuilder::new().logging(true);
1184        assert!(builder.config.enable_logging);
1185    }
1186
1187    #[test]
1188    fn test_executor_builder_build() {
1189        let executor = ExecutorBuilder::new()
1190            .timeout(Duration::from_secs(60))
1191            .retries(5)
1192            .logging(true)
1193            .build();
1194
1195        assert_eq!(executor.config.timeout, Some(Duration::from_secs(60)));
1196        assert_eq!(executor.config.retry_config.max_retries, 5);
1197        assert!(executor.config.enable_logging);
1198    }
1199
1200    #[test]
1201    fn test_executor_builder_chainable() {
1202        let builder = ExecutorBuilder::new()
1203            .timeout(Duration::from_secs(45))
1204            .retries(3)
1205            .logging(false)
1206            .timeout(Duration::from_secs(50))
1207            .retries(4)
1208            .logging(true);
1209
1210        assert_eq!(builder.config.timeout, Some(Duration::from_secs(50)));
1211        assert_eq!(builder.config.retry_config.max_retries, 4);
1212        assert!(builder.config.enable_logging);
1213    }
1214
1215    #[test]
1216    fn test_export_tool_as_function() {
1217        let executor = ToolExecutor::new();
1218
1219        let tool = FunctionTool::builder("greet_tool", "Greet someone")
1220            .handler(|_args| async move { Ok(serde_json::json!({"greeting": "hello"})) })
1221            .build()
1222            .unwrap();
1223
1224        executor.add_dyn_tool(Box::new(tool));
1225
1226        let exported = executor.export_tool_as_function("greet_tool");
1227        assert!(exported.is_some());
1228
1229        if let Some(Tools::Function { function }) = exported {
1230            assert_eq!(function.name, "greet_tool");
1231            assert_eq!(function.description, "Greet someone");
1232            // Schema is auto-generated with default values
1233            assert!(function.parameters.is_some());
1234        } else {
1235            panic!("Expected Tools::Function");
1236        }
1237    }
1238
1239    #[test]
1240    fn test_export_tool_as_function_nonexistent() {
1241        let executor = ToolExecutor::new();
1242        let exported = executor.export_tool_as_function("nonexistent");
1243        assert!(exported.is_none());
1244    }
1245
1246    #[test]
1247    fn test_export_all_tools_as_functions() {
1248        let executor = ToolExecutor::new();
1249
1250        let tool1 = FunctionTool::builder("tool1", "First tool")
1251            .handler(|_args| async move { Ok(serde_json::json!({})) })
1252            .build()
1253            .unwrap();
1254
1255        let tool2 = FunctionTool::builder("tool2", "Second tool")
1256            .handler(|_args| async move { Ok(serde_json::json!({})) })
1257            .build()
1258            .unwrap();
1259
1260        executor.add_dyn_tool(Box::new(tool1));
1261        executor.add_dyn_tool(Box::new(tool2));
1262
1263        let exported = executor.export_all_tools_as_functions();
1264        assert_eq!(exported.len(), 2);
1265
1266        let names: Vec<_> = exported
1267            .iter()
1268            .filter_map(|t| match t {
1269                Tools::Function { function } => Some(function.name.clone()),
1270                _ => None,
1271            })
1272            .collect();
1273
1274        assert!(names.contains(&"tool1".to_string()));
1275        assert!(names.contains(&"tool2".to_string()));
1276    }
1277
1278    #[test]
1279    fn test_export_tools_filtered() {
1280        let executor = ToolExecutor::new();
1281
1282        let tool1 = FunctionTool::builder("math_tool", "Math operations")
1283            .metadata(|m| m.version("1.0.0"))
1284            .handler(|_args| async move { Ok(serde_json::json!({})) })
1285            .build()
1286            .unwrap();
1287
1288        let tool2 = FunctionTool::builder("text_tool", "Text operations")
1289            .metadata(|m| m.version("2.0.0"))
1290            .handler(|_args| async move { Ok(serde_json::json!({})) })
1291            .build()
1292            .unwrap();
1293
1294        executor.add_dyn_tool(Box::new(tool1));
1295        executor.add_dyn_tool(Box::new(tool2));
1296
1297        let exported = executor.export_tools_filtered(|meta| meta.version == "1.0.0");
1298        assert_eq!(exported.len(), 1);
1299
1300        if let Some(Tools::Function { function }) = exported.first() {
1301            assert_eq!(function.name, "math_tool");
1302        } else {
1303            panic!("Expected Tools::Function");
1304        }
1305    }
1306
1307    #[test]
1308    fn test_execution_result_metadata_serialization() {
1309        let result = ExecutionResult::success(
1310            "test_tool".to_string(),
1311            serde_json::json!({"value": 42}),
1312            Duration::from_millis(100),
1313            0,
1314        )
1315        .with_metadata("key1", serde_json::json!("value1"))
1316        .with_metadata("key2", serde_json::json!(123));
1317
1318        let json = serde_json::to_string(&result).unwrap();
1319        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
1320
1321        assert_eq!(parsed["metadata"]["key1"], "value1");
1322        assert_eq!(parsed["metadata"]["key2"], 123);
1323    }
1324
1325    #[test]
1326    fn test_execution_result_timestamp() {
1327        let before = std::time::SystemTime::now();
1328        let result = ExecutionResult::success(
1329            "test_tool".to_string(),
1330            serde_json::json!({"value": 42}),
1331            Duration::from_millis(100),
1332            0,
1333        );
1334        let after = std::time::SystemTime::now();
1335
1336        assert!(result.timestamp >= before && result.timestamp <= after);
1337    }
1338}