// tool_useful/executor.rs

//! High-performance tool execution engine with retry, timeout, and resource management.
2
3use crate::{ErrorKind, ToolError, ToolResult};
4use async_trait::async_trait;
5use futures::stream::{self, StreamExt};
6use parking_lot::RwLock;
7use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{watch, Semaphore};
11use tracing::{debug, instrument, warn};
12
/// Context provided during tool execution
///
/// Cheap to clone: the cancellation receiver and metadata map live behind
/// `Arc`s, so every clone observes the same cancellation signal and shares
/// the same metadata.
#[derive(Clone)]
pub struct ExecutionContext {
    // Cancellation flag; `true` means the caller asked execution to stop.
    cancellation: Arc<watch::Receiver<bool>>,
    /// Optional per-execution timeout; when `None`, the executor's
    /// `default_timeout` is used instead (see `Executor::execute_internal`).
    pub timeout: Option<Duration>,
    /// Optional memory budget in bytes.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub max_memory: Option<usize>,
    /// Free-form key/value metadata shared across clones of this context.
    pub metadata: Arc<RwLock<ExecutionMetadata>>,
    // Creation time; basis for `elapsed()`.
    started_at: Instant,
}
22
23impl ExecutionContext {
24    pub fn new() -> Self {
25        let (tx, rx) = watch::channel(false);
26        std::mem::drop(tx);
27        Self {
28            cancellation: Arc::new(rx),
29            timeout: None,
30            max_memory: None,
31            metadata: Arc::new(RwLock::new(ExecutionMetadata::default())),
32            started_at: Instant::now(),
33        }
34    }
35
36    pub fn with_cancellation(cancellation: watch::Receiver<bool>) -> Self {
37        Self {
38            cancellation: Arc::new(cancellation),
39            timeout: None,
40            max_memory: None,
41            metadata: Arc::new(RwLock::new(ExecutionMetadata::default())),
42            started_at: Instant::now(),
43        }
44    }
45
46    pub fn with_timeout(mut self, timeout: Duration) -> Self {
47        self.timeout = Some(timeout);
48        self
49    }
50
51    pub fn elapsed(&self) -> Duration {
52        self.started_at.elapsed()
53    }
54
55    pub fn is_cancelled(&self) -> bool {
56        *self.cancellation.borrow()
57    }
58
59    pub fn check_cancelled(&self) -> ToolResult<()> {
60        if self.is_cancelled() {
61            Err(ToolError::Cancelled)
62        } else {
63            Ok(())
64        }
65    }
66
67    pub fn set_metadata<V: serde::Serialize>(&self, key: impl Into<String>, value: V) {
68        if let Ok(v) = serde_json::to_value(value) {
69            self.metadata.write().fields.insert(key.into(), v);
70        }
71    }
72
73    pub fn get_metadata<T: for<'de> serde::Deserialize<'de>>(&self, key: &str) -> Option<T> {
74        self.metadata
75            .read()
76            .fields
77            .get(key)
78            .and_then(|v| serde_json::from_value(v.clone()).ok())
79    }
80}
81
82impl Default for ExecutionContext {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
/// Arbitrary key/value pairs attached to a single execution.
#[derive(Default)]
pub struct ExecutionMetadata {
    /// Backing map; values are stored as JSON so heterogeneous types can be
    /// round-tripped through `set_metadata`/`get_metadata`.
    pub fields: std::collections::HashMap<String, serde_json::Value>,
}
92
93/// Trait for types that can execute as tools
94#[async_trait]
95pub trait ToolExecutor: Send + Sync {
96    type Output: serde::Serialize + Send;
97    type Error: std::error::Error + Send + Sync + 'static;
98
99    async fn execute(&self, ctx: &ExecutionContext) -> Result<Self::Output, Self::Error>;
100
101    async fn execute_tool(&self, ctx: &ExecutionContext) -> ToolResult<Self::Output> {
102        self.execute(ctx).await.map_err(|e| ToolError::custom(e))
103    }
104}
105
/// Retry policy with circuit breaker support
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Retry budget; once exceeded the last error is returned to the caller.
    pub max_attempts: u32,
    /// Delay for the first retry; later delays derive from it per `strategy`.
    pub base_delay: Duration,
    /// Ceiling applied to the computed delay.
    pub max_delay: Duration,
    /// How the delay grows across attempts.
    pub strategy: RetryStrategy,
    /// Error kinds worth retrying; any other kind fails immediately.
    pub retryable_errors: Vec<ErrorKind>,
    /// When `true`, the delay is randomized to avoid thundering-herd retries.
    pub jitter: bool,
}
116
/// Shape of the backoff curve between retry attempts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// Same delay for every attempt.
    Fixed,
    /// Delay doubles each attempt: `base_delay * 2^(attempt - 1)`.
    Exponential,
    /// Delay grows linearly: `base_delay * attempt`.
    Linear,
}
123
124impl RetryPolicy {
125    pub fn exponential(max_attempts: u32) -> Self {
126        Self {
127            max_attempts,
128            base_delay: Duration::from_millis(100),
129            max_delay: Duration::from_secs(30),
130            strategy: RetryStrategy::Exponential,
131            retryable_errors: vec![ErrorKind::Network, ErrorKind::Timeout, ErrorKind::Resource],
132            jitter: true,
133        }
134    }
135
136    pub fn fixed(max_attempts: u32, delay: Duration) -> Self {
137        Self {
138            max_attempts,
139            base_delay: delay,
140            max_delay: delay,
141            strategy: RetryStrategy::Fixed,
142            retryable_errors: vec![ErrorKind::Network, ErrorKind::Timeout, ErrorKind::Resource],
143            jitter: false,
144        }
145    }
146
147    pub fn with_backoff(mut self, delay: Duration) -> Self {
148        self.base_delay = delay;
149        self
150    }
151
152    pub fn with_max_delay(mut self, delay: Duration) -> Self {
153        self.max_delay = delay;
154        self
155    }
156
157    pub fn with_jitter(mut self, jitter: bool) -> Self {
158        self.jitter = jitter;
159        self
160    }
161
162    pub fn should_retry(&self, error: &ToolError) -> bool {
163        self.retryable_errors.contains(&error.kind())
164    }
165
166    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
167        let delay = match self.strategy {
168            RetryStrategy::Fixed => self.base_delay,
169            RetryStrategy::Exponential => {
170                let multiplier = 2u32.pow(attempt.saturating_sub(1));
171                self.base_delay.saturating_mul(multiplier)
172            }
173            RetryStrategy::Linear => self.base_delay.saturating_mul(attempt),
174        };
175
176        let delay = delay.min(self.max_delay);
177
178        if self.jitter {
179            // Add jitter: random value between 0.5 and 1.5 times the delay
180            let jitter_factor = 0.5 + (rand::random::<f64>() * 1.0);
181            Duration::from_secs_f64(delay.as_secs_f64() * jitter_factor)
182        } else {
183            delay
184        }
185    }
186}
187
/// Circuit breaker to prevent cascading failures
///
/// Cloning shares the underlying state (all counters and the state flag live
/// behind `Arc`s), so clones trip and recover together.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    // Failures required to open the circuit.
    failure_threshold: u32,
    // Successes required to close the circuit again (fixed to 2 in `new`).
    success_threshold: u32,
    // How long the circuit stays open before another call is allowed through.
    timeout: Duration,
    // Current Closed/Open/HalfOpen state.
    state: Arc<RwLock<CircuitBreakerState>>,
    // Running failure counter.
    failures: Arc<AtomicU64>,
    // Running success counter.
    successes: Arc<AtomicU64>,
    // Timestamp of the most recent failure, if any.
    last_failure_time: Arc<RwLock<Option<Instant>>>,
}
199
/// Lifecycle states of the circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitBreakerState {
    /// Normal operation; calls pass through.
    Closed,
    /// Tripped; calls are rejected until the open timeout elapses.
    Open,
    /// Probing after the open timeout; successes here can close the circuit.
    HalfOpen,
}
206
207impl CircuitBreaker {
208    pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
209        Self {
210            failure_threshold,
211            success_threshold: 2,
212            timeout,
213            state: Arc::new(RwLock::new(CircuitBreakerState::Closed)),
214            failures: Arc::new(AtomicU64::new(0)),
215            successes: Arc::new(AtomicU64::new(0)),
216            last_failure_time: Arc::new(RwLock::new(None)),
217        }
218    }
219
220    pub fn call<F, Fut, T>(&self, f: F) -> impl std::future::Future<Output = ToolResult<T>>
221    where
222        F: FnOnce() -> Fut,
223        Fut: std::future::Future<Output = ToolResult<T>>,
224    {
225        let state = *self.state.read();
226        let should_attempt = match state {
227            CircuitBreakerState::Open => {
228                if let Some(last_failure) = *self.last_failure_time.read() {
229                    last_failure.elapsed() > self.timeout
230                } else {
231                    false
232                }
233            }
234            _ => true,
235        };
236
237        let failures = self.failures.clone();
238        let successes = self.successes.clone();
239        let state_arc = self.state.clone();
240        let last_failure = self.last_failure_time.clone();
241        let failure_threshold = self.failure_threshold;
242        let success_threshold = self.success_threshold;
243
244        async move {
245            if !should_attempt {
246                return Err(ToolError::execution_failed("Circuit breaker is open"));
247            }
248
249            match f().await {
250                Ok(result) => {
251                    successes.fetch_add(1, Ordering::Relaxed);
252                    let success_count = successes.load(Ordering::Relaxed);
253
254                    if success_count >= success_threshold as u64 {
255                        *state_arc.write() = CircuitBreakerState::Closed;
256                        failures.store(0, Ordering::Relaxed);
257                        successes.store(0, Ordering::Relaxed);
258                    }
259
260                    Ok(result)
261                }
262                Err(err) => {
263                    failures.fetch_add(1, Ordering::Relaxed);
264                    *last_failure.write() = Some(Instant::now());
265
266                    if failures.load(Ordering::Relaxed) >= failure_threshold as u64 {
267                        *state_arc.write() = CircuitBreakerState::Open;
268                    }
269
270                    Err(err)
271                }
272            }
273        }
274    }
275}
276
/// High-performance executor with advanced features
///
/// Cloning is cheap and shares configuration, the concurrency semaphore,
/// metrics, and the optional circuit breaker.
#[derive(Clone)]
pub struct Executor {
    // Immutable configuration fixed at build time.
    config: Arc<ExecutorConfig>,
    // Caps in-flight executions at `config.max_concurrent` permits.
    semaphore: Arc<Semaphore>,
    // Shared counters updated on every execution.
    metrics: Arc<ExecutorMetrics>,
    // Optional breaker that short-circuits calls after repeated failures.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
285
/// Build-time executor settings; populated via `ExecutorBuilder`.
#[derive(Debug)]
struct ExecutorConfig {
    // Timeout applied when the execution context specifies none.
    default_timeout: Option<Duration>,
    // Maximum number of concurrent executions (semaphore permits).
    max_concurrent: usize,
    // Optional retry policy applied to every execution.
    retry_policy: Option<RetryPolicy>,
    // NOTE(review): set by the builder but never read in this file — confirm
    // intended use.
    enable_tracing: bool,
}
293
294impl Default for ExecutorConfig {
295    fn default() -> Self {
296        Self {
297            default_timeout: Some(Duration::from_secs(30)),
298            max_concurrent: 100,
299            retry_policy: None,
300            enable_tracing: true,
301        }
302    }
303}
304
/// Aggregate counters describing executor activity; all fields are atomics
/// so the struct can be shared and updated without locks.
#[derive(Debug, Default)]
pub struct ExecutorMetrics {
    pub total_executions: AtomicUsize,
    pub successful_executions: AtomicUsize,
    pub failed_executions: AtomicUsize,
    pub total_duration_ms: AtomicU64,
}

impl ExecutorMetrics {
    /// Percentage of executions that succeeded; `0.0` when nothing has run.
    pub fn success_rate(&self) -> f64 {
        match self.total_executions.load(Ordering::Relaxed) {
            0 => 0.0,
            total => {
                let ok = self.successful_executions.load(Ordering::Relaxed);
                ok as f64 / total as f64 * 100.0
            }
        }
    }

    /// Mean wall-clock duration per execution, in milliseconds; `0.0` when
    /// nothing has run.
    pub fn avg_duration_ms(&self) -> f64 {
        let total = self.total_executions.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            self.total_duration_ms.load(Ordering::Relaxed) as f64 / total as f64
        }
    }
}
332
impl Executor {
    /// Creates an executor with the default configuration: 30s timeout,
    /// 100 concurrent executions, no retries, no circuit breaker.
    pub fn new() -> Self {
        let config = ExecutorConfig::default();
        let max_concurrent = config.max_concurrent;
        Self {
            config: Arc::new(config),
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            metrics: Arc::new(ExecutorMetrics::default()),
            circuit_breaker: None,
        }
    }

    /// Entry point for building a customized executor.
    pub fn builder() -> ExecutorBuilder {
        ExecutorBuilder::new()
    }

    /// Read-only access to the shared execution counters.
    pub fn metrics(&self) -> &ExecutorMetrics {
        &self.metrics
    }

    /// Executes `tool` with a fresh default (never-cancelled) context.
    #[instrument(skip(self, tool))]
    pub async fn execute<T>(&self, tool: &T) -> ToolResult<T::Output>
    where
        T: ToolExecutor,
    {
        let ctx = ExecutionContext::new();
        self.execute_with_context(tool, &ctx).await
    }

    /// Executes `tool` under `ctx`, applying concurrency limiting, the
    /// optional circuit breaker, metrics accounting, and logging.
    pub async fn execute_with_context<T>(
        &self,
        tool: &T,
        ctx: &ExecutionContext,
    ) -> ToolResult<T::Output>
    where
        T: ToolExecutor,
    {
        // Acquire semaphore permit for concurrency control. The permit is
        // held (RAII) until this function returns; acquire() only errors if
        // the semaphore is closed, which this executor never does.
        let _permit = self
            .semaphore
            .acquire()
            .await
            .map_err(|_| ToolError::execution_failed("Failed to acquire execution permit"))?;

        let start = Instant::now();
        self.metrics
            .total_executions
            .fetch_add(1, Ordering::Relaxed);

        // Route through the circuit breaker when one is configured.
        let result = if let Some(ref cb) = self.circuit_breaker {
            cb.call(|| self.execute_internal(tool, ctx)).await
        } else {
            self.execute_internal(tool, ctx).await
        };

        // Duration accounting covers the whole attempt (including retries
        // and breaker rejection), measured from after permit acquisition.
        let duration = start.elapsed();
        self.metrics
            .total_duration_ms
            .fetch_add(duration.as_millis() as u64, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                self.metrics
                    .successful_executions
                    .fetch_add(1, Ordering::Relaxed);
                debug!("Tool execution succeeded in {:?}", duration);
            }
            Err(e) => {
                self.metrics
                    .failed_executions
                    .fetch_add(1, Ordering::Relaxed);
                warn!("Tool execution failed: {} (duration: {:?})", e, duration);
            }
        }

        result
    }

    /// Chooses the execution strategy: retry (with per-attempt timeout),
    /// plain timeout, or a bare call. The context's timeout takes priority
    /// over the executor's default.
    async fn execute_internal<T>(&self, tool: &T, ctx: &ExecutionContext) -> ToolResult<T::Output>
    where
        T: ToolExecutor,
    {
        let timeout = ctx.timeout.or(self.config.default_timeout);

        if let Some(ref retry_policy) = self.config.retry_policy {
            self.execute_with_retry(tool, ctx, retry_policy, timeout)
                .await
        } else if let Some(timeout_duration) = timeout {
            self.execute_with_timeout(tool, ctx, timeout_duration).await
        } else {
            tool.execute_tool(ctx).await
        }
    }

    /// Runs the tool once, converting an elapsed deadline into
    /// `ToolError::Timeout`.
    async fn execute_with_timeout<T>(
        &self,
        tool: &T,
        ctx: &ExecutionContext,
        timeout: Duration,
    ) -> ToolResult<T::Output>
    where
        T: ToolExecutor,
    {
        tokio::time::timeout(timeout, tool.execute_tool(ctx))
            .await
            .map_err(|_| ToolError::Timeout(timeout))?
    }

    /// Runs the tool with retries per `policy`; `timeout` bounds each
    /// individual attempt, not the whole sequence.
    ///
    /// NOTE(review): with `max_attempts = N` this loop can execute the tool
    /// up to N + 1 times (initial try plus N retries) — confirm that is the
    /// intended reading of `max_attempts`.
    async fn execute_with_retry<T>(
        &self,
        tool: &T,
        ctx: &ExecutionContext,
        policy: &RetryPolicy,
        timeout: Option<Duration>,
    ) -> ToolResult<T::Output>
    where
        T: ToolExecutor,
    {
        let mut attempts = 0;
        let mut last_error = None;

        while attempts <= policy.max_attempts {
            let result = if let Some(timeout_duration) = timeout {
                self.execute_with_timeout(tool, ctx, timeout_duration).await
            } else {
                tool.execute_tool(ctx).await
            };

            match result {
                Ok(output) => return Ok(output),
                Err(err) => {
                    attempts += 1;
                    // Non-retryable errors and an exhausted budget fail
                    // immediately with the latest error.
                    if !policy.should_retry(&err) || attempts > policy.max_attempts {
                        return Err(err);
                    }
                    last_error = Some(err);
                    let delay = policy.calculate_backoff(attempts);
                    debug!("Retrying after {:?} (attempt {})", delay, attempts);
                    tokio::time::sleep(delay).await;
                }
            }
        }

        // Unreachable in practice: every iteration returns on Ok or on a
        // terminal Err above; kept as a defensive fallback.
        Err(last_error
            .unwrap_or_else(|| ToolError::execution_failed("Max retry attempts exceeded")))
    }

    /// Execute multiple tools in parallel
    ///
    /// At most `config.max_concurrent` run at once; results are collected in
    /// completion order (buffer_unordered), not input order.
    pub async fn execute_batch<T>(&self, tools: Vec<T>) -> Vec<ToolResult<T::Output>>
    where
        T: ToolExecutor + Clone,
    {
        stream::iter(tools)
            .map(|tool| async move { self.execute(&tool).await })
            .buffer_unordered(self.config.max_concurrent)
            .collect()
            .await
    }
}
492
493impl Default for Executor {
494    fn default() -> Self {
495        Self::new()
496    }
497}
498
/// Builder for creating configured executors
#[derive(Default)]
pub struct ExecutorBuilder {
    // Settings accumulated by the builder methods.
    config: ExecutorConfig,
    // Breaker to install, if `circuit_breaker()` was called.
    circuit_breaker: Option<CircuitBreaker>,
}
505
506impl ExecutorBuilder {
507    pub fn new() -> Self {
508        Self {
509            config: ExecutorConfig::default(),
510            circuit_breaker: None,
511        }
512    }
513
514    pub fn timeout(mut self, timeout: Duration) -> Self {
515        self.config.default_timeout = Some(timeout);
516        self
517    }
518
519    pub fn max_concurrent(mut self, max: usize) -> Self {
520        self.config.max_concurrent = max;
521        self
522    }
523
524    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
525        self.config.retry_policy = Some(policy);
526        self
527    }
528
529    pub fn circuit_breaker(mut self, failure_threshold: u32, timeout: Duration) -> Self {
530        self.circuit_breaker = Some(CircuitBreaker::new(failure_threshold, timeout));
531        self
532    }
533
534    pub fn enable_tracing(mut self, enable: bool) -> Self {
535        self.config.enable_tracing = enable;
536        self
537    }
538
539    pub fn build(self) -> Executor {
540        let max_concurrent = self.config.max_concurrent;
541        Executor {
542            config: Arc::new(self.config),
543            semaphore: Arc::new(Semaphore::new(max_concurrent)),
544            metrics: Arc::new(ExecutorMetrics::default()),
545            circuit_breaker: self.circuit_breaker.map(Arc::new),
546        }
547    }
548}
549
// Minimal in-crate RNG (xorshift64) used for retry jitter, avoiding an
// external `rand` dependency.
mod rand {
    use std::cell::Cell;

    thread_local! {
        // Per-thread xorshift64 state; fixed nonzero seed, so each thread
        // produces the same deterministic sequence.
        static RNG: Cell<u64> = Cell::new(0x4d595df4d0f33173);
    }

    /// Returns the next pseudo-random value of type `T`.
    pub fn random<T: SampleUniform>() -> T {
        T::sample_uniform()
    }

    /// Types that can be drawn from this generator.
    pub trait SampleUniform: Sized {
        fn sample_uniform() -> Self;
    }

    impl SampleUniform for f64 {
        // Advances the xorshift64 state and scales the result into [0, 1].
        fn sample_uniform() -> Self {
            RNG.with(|state| {
                let mut s = state.get();
                s ^= s << 13;
                s ^= s >> 7;
                s ^= s << 17;
                state.set(s);
                s as f64 / u64::MAX as f64
            })
        }
    }
}