turbomcp_core/
utils.rs

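//! Utility functions and helper macros: a future timeout wrapper, retry with
//! exponential backoff, a simple circuit breaker, and timing and feature-gating
//! macros.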
2
3use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use pin_project_lite::pin_project;
9use tokio::time::{Sleep, sleep};
10
11pin_project! {
12    /// Timeout wrapper for futures
13    pub struct Timeout<F> {
14        #[pin]
15        future: F,
16        #[pin]
17        delay: Sleep,
18    }
19}
20
21impl<F> Timeout<F> {
22    /// Create a new timeout wrapper
23    pub fn new(future: F, duration: Duration) -> Self {
24        Self {
25            future,
26            delay: sleep(duration),
27        }
28    }
29}
30
31impl<F> Future for Timeout<F>
32where
33    F: Future,
34{
35    type Output = Result<F::Output, TimeoutError>;
36
37    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
38        let this = self.project();
39
40        // First check if the future is ready
41        if let Poll::Ready(output) = this.future.poll(cx) {
42            return Poll::Ready(Ok(output));
43        }
44
45        // Then check if the timeout has expired
46        match this.delay.poll(cx) {
47            Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
48            Poll::Pending => Poll::Pending,
49        }
50    }
51}
52
/// Error produced by [`Timeout`] when its deadline elapses before the wrapped
/// future completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Operation timed out")
    }
}

/// Marker impl: a timeout has no underlying cause, so `source()` is `None`.
impl std::error::Error for TimeoutError {}
64
65/// Utility function to create a timeout future
66pub fn timeout<F>(duration: Duration, future: F) -> Timeout<F>
67where
68    F: Future,
69{
70    Timeout::new(future, duration)
71}
72
73/// Retry configuration
74#[derive(Debug, Clone)]
75pub struct RetryConfig {
76    /// Maximum number of attempts
77    pub max_attempts: usize,
78    /// Base delay between attempts
79    pub base_delay: Duration,
80    /// Maximum delay between attempts
81    pub max_delay: Duration,
82    /// Backoff multiplier
83    pub backoff_multiplier: f64,
84    /// Whether to add jitter
85    pub jitter: bool,
86}
87
88impl Default for RetryConfig {
89    fn default() -> Self {
90        Self {
91            max_attempts: 3,
92            base_delay: Duration::from_millis(100),
93            max_delay: Duration::from_secs(30),
94            backoff_multiplier: 2.0,
95            jitter: true,
96        }
97    }
98}
99
100impl RetryConfig {
101    /// Create a new retry configuration
102    #[must_use]
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Set maximum attempts
108    #[must_use]
109    pub const fn with_max_attempts(mut self, max_attempts: usize) -> Self {
110        self.max_attempts = max_attempts;
111        self
112    }
113
114    /// Set base delay
115    #[must_use]
116    pub const fn with_base_delay(mut self, delay: Duration) -> Self {
117        self.base_delay = delay;
118        self
119    }
120
121    /// Set maximum delay
122    #[must_use]
123    pub const fn with_max_delay(mut self, delay: Duration) -> Self {
124        self.max_delay = delay;
125        self
126    }
127
128    /// Set backoff multiplier
129    #[must_use]
130    pub const fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
131        self.backoff_multiplier = multiplier;
132        self
133    }
134
135    /// Enable or disable jitter
136    #[must_use]
137    pub const fn with_jitter(mut self, jitter: bool) -> Self {
138        self.jitter = jitter;
139        self
140    }
141
142    /// Calculate delay for the given attempt number
143    #[must_use]
144    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
145        if attempt == 0 {
146            return Duration::ZERO;
147        }
148
149        let base_delay_ms = self.base_delay.as_millis() as f64;
150        let multiplier = self.backoff_multiplier.powi((attempt - 1) as i32);
151        let delay_ms = base_delay_ms * multiplier;
152
153        let delay = Duration::from_millis(delay_ms as u64).min(self.max_delay);
154
155        if self.jitter {
156            let jitter_factor = (rand::random::<f64>() - 0.5).mul_add(0.1, 1.0); // ±5% jitter
157            let jittered_delay = delay.mul_f64(jitter_factor);
158            jittered_delay.min(self.max_delay)
159        } else {
160            delay
161        }
162    }
163}
164
165/// Retry a future with exponential backoff
166///
167/// # Panics
168///
169/// Panics if no retry attempts are made and no error is captured
170pub async fn retry_with_backoff<F, Fut, T, E>(
171    mut operation: F,
172    config: RetryConfig,
173    should_retry: impl Fn(&E) -> bool,
174) -> Result<T, E>
175where
176    F: FnMut() -> Fut,
177    Fut: Future<Output = Result<T, E>>,
178{
179    let mut last_error = None;
180
181    for attempt in 0..config.max_attempts {
182        match operation().await {
183            Ok(result) => return Ok(result),
184            Err(error) => {
185                if !should_retry(&error) || attempt + 1 >= config.max_attempts {
186                    return Err(error);
187                }
188
189                let delay = config.delay_for_attempt(attempt + 1);
190                sleep(delay).await;
191                last_error = Some(error);
192            }
193        }
194    }
195
196    // This should never happen since we always set last_error before breaking
197    // But if it does, we need to return some error. Use expect to catch this bug.
198    Err(last_error.expect("Retry loop ended without attempts - this is a bug in retry logic"))
199}
200
/// Operating mode of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls pass through normally while failures are counted.
    Closed,
    /// Calls are rejected immediately until the recovery timeout elapses.
    Open,
    /// Calls are let through again to probe whether the downstream has recovered.
    HalfOpen,
}
211
212/// Simple circuit breaker implementation
213#[derive(Debug)]
214pub struct CircuitBreaker {
215    state: parking_lot::Mutex<CircuitBreakerState>,
216    failure_threshold: usize,
217    recovery_timeout: Duration,
218    success_threshold: usize,
219}
220
221#[derive(Debug)]
222struct CircuitBreakerState {
223    state: CircuitState,
224    failure_count: usize,
225    success_count: usize,
226    last_failure_time: Option<std::time::Instant>,
227}
228
229impl CircuitBreaker {
230    /// Create a new circuit breaker
231    #[must_use]
232    pub const fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
233        Self {
234            state: parking_lot::Mutex::new(CircuitBreakerState {
235                state: CircuitState::Closed,
236                failure_count: 0,
237                success_count: 0,
238                last_failure_time: None,
239            }),
240            failure_threshold,
241            recovery_timeout,
242            success_threshold: 3,
243        }
244    }
245
246    /// Execute an operation through the circuit breaker
247    pub async fn call<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
248    where
249        F: FnOnce() -> Fut,
250        Fut: Future<Output = Result<T, E>>,
251    {
252        // Check if circuit is open
253        if self.is_open() {
254            return Err(CircuitBreakerError::Open);
255        }
256
257        // Execute the operation
258        match operation().await {
259            Ok(result) => {
260                self.record_success();
261                Ok(result)
262            }
263            Err(error) => {
264                self.record_failure();
265                Err(CircuitBreakerError::Operation(error))
266            }
267        }
268    }
269
270    /// Get current circuit state
271    pub fn state(&self) -> CircuitState {
272        self.state.lock().state
273    }
274
275    fn is_open(&self) -> bool {
276        let mut state = self.state.lock();
277
278        match state.state {
279            CircuitState::Open => {
280                // Check if recovery timeout has passed
281                state.last_failure_time.is_none_or(|last_failure| {
282                    if last_failure.elapsed() >= self.recovery_timeout {
283                        state.state = CircuitState::HalfOpen;
284                        state.success_count = 0;
285                        false
286                    } else {
287                        true
288                    }
289                })
290            }
291            _ => false,
292        }
293    }
294
295    fn record_success(&self) {
296        let mut state = self.state.lock();
297
298        match state.state {
299            CircuitState::Closed => {
300                state.failure_count = 0;
301            }
302            CircuitState::HalfOpen => {
303                state.success_count += 1;
304                if state.success_count >= self.success_threshold {
305                    state.state = CircuitState::Closed;
306                    state.failure_count = 0;
307                    state.success_count = 0;
308                }
309            }
310            CircuitState::Open => {
311                // Should not reach here
312            }
313        }
314    }
315
316    fn record_failure(&self) {
317        let mut state = self.state.lock();
318
319        state.failure_count += 1;
320        state.last_failure_time = Some(std::time::Instant::now());
321
322        match state.state {
323            CircuitState::Closed => {
324                if state.failure_count >= self.failure_threshold {
325                    state.state = CircuitState::Open;
326                }
327            }
328            CircuitState::HalfOpen => {
329                state.state = CircuitState::Open;
330                state.success_count = 0;
331            }
332            CircuitState::Open => {
333                // Already open
334            }
335        }
336    }
337}
338
/// Failure returned by [`CircuitBreaker::call`].
///
/// Distinguishes calls that were never attempted because the circuit was open
/// from calls that ran and returned an error of their own.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// The circuit is open; the operation was not executed.
    Open,
    /// The operation ran and returned this error.
    Operation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Operation(inner) = self {
            write!(f, "Operation failed: {inner}")
        } else {
            f.write_str("Circuit breaker is open")
        }
    }
}

/// Exposes the wrapped operation error as the `source` of an `Operation` failure.
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Operation(inner) => Some(inner),
            Self::Open => None,
        }
    }
}
371
/// Evaluate a block, measure how long it took, and return the block's value.
///
/// `$name` labels the measurement. When the `tracing` feature is enabled, the
/// elapsed time is emitted with `tracing::debug!`. Otherwise `$name` is not
/// evaluated and the measurement is discarded.
///
/// Note: a `#[cfg(feature = "tracing")]` inside an exported `macro_rules!`
/// expansion is checked against the features of the crate that *invokes* the
/// macro. That crate must also depend on `tracing` for the logging call to
/// resolve.
#[macro_export]
macro_rules! measure_time {
    ($name:expr, $block:block) => {{
        let __measure_started = std::time::Instant::now();
        let __measure_value = $block;
        // Underscore-prefixed: unused when the `tracing` feature is off.
        let _took = __measure_started.elapsed();

        #[cfg(feature = "tracing")]
        tracing::debug!("{} took {:?}", $name, _took);

        __measure_value
    }};
}
386
/// Include a block only when a Cargo feature is enabled, optionally with a
/// fallback block for when it is not.
///
/// The feature name must be a string literal, because `cfg(feature = ...)`
/// accepts nothing else. Matching it as `literal` reports a bad argument at the
/// call site instead of inside the expansion.
///
/// The expansion is a sequence of `#[cfg(...)]`-gated block statements, so use
/// it in statement position. It does not produce a value. As with any `cfg`
/// inside an exported `macro_rules!`, the feature is checked against the crate
/// that invokes the macro.
///
/// ```ignore
/// feature_gate!("metrics", { record_metrics(); });
/// feature_gate!("tracing", { trace_it(); }, { log_it(); });
/// ```
#[macro_export]
macro_rules! feature_gate {
    ($feature:literal, $block:block) => {
        #[cfg(feature = $feature)]
        $block
    };
    ($feature:literal, $if_block:block, $else_block:block) => {
        #[cfg(feature = $feature)]
        $if_block
        #[cfg(not(feature = $feature))]
        $else_block
    };
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Send one call through `breaker`. `calls` is bumped only if the operation
    /// actually runs, and the operation resolves to `outcome`.
    async fn run_through(
        breaker: &CircuitBreaker,
        calls: &Arc<AtomicU32>,
        outcome: Result<(), &'static str>,
    ) -> Result<(), CircuitBreakerError<&'static str>> {
        let calls = Arc::clone(calls);
        breaker
            .call(move || async move {
                calls.fetch_add(1, Ordering::SeqCst);
                outcome
            })
            .await
    }

    #[tokio::test]
    async fn test_timeout() {
        // An immediately-ready future finishes well inside the deadline.
        let quick = timeout(Duration::from_millis(100), async { 42 }).await;
        assert_eq!(quick.unwrap(), 42);

        // A future that sleeps past the deadline is cut off.
        let slow = timeout(Duration::from_millis(10), async {
            sleep(Duration::from_millis(50)).await;
            42
        })
        .await;
        assert!(slow.is_err());
    }

    #[test]
    fn test_retry_config() {
        let cfg = RetryConfig::new()
            .with_max_attempts(5)
            .with_base_delay(Duration::from_millis(50))
            .with_jitter(false);

        assert_eq!(cfg.max_attempts, 5);
        assert_eq!(cfg.base_delay, Duration::from_millis(50));
        assert!(!cfg.jitter);

        // Attempt 0 waits nothing; later attempts grow geometrically from the base.
        let expected = [
            (0, Duration::ZERO),
            (1, Duration::from_millis(50)),
            (2, Duration::from_millis(100)),
        ];
        for (attempt, want) in expected {
            assert_eq!(cfg.delay_for_attempt(attempt), want);
        }
    }

    #[tokio::test]
    async fn test_retry_with_backoff() {
        let invocations = Arc::new(AtomicU32::new(0));

        let config = RetryConfig::new()
            .with_max_attempts(3)
            .with_base_delay(Duration::from_millis(1))
            .with_jitter(false);

        let op_counter = Arc::clone(&invocations);
        let outcome = retry_with_backoff(
            move || {
                // The first two invocations fail; the third succeeds.
                let previous = op_counter.fetch_add(1, Ordering::SeqCst);
                async move {
                    if previous >= 2 {
                        Ok("success")
                    } else {
                        Err("fail")
                    }
                }
            },
            config,
            |_| true,
        )
        .await;

        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let breaker = CircuitBreaker::new(2, Duration::from_millis(10));
        let calls = Arc::new(AtomicU32::new(0));

        // First failure is below the threshold, so the circuit stays closed.
        let first = run_through(&breaker, &calls, Err("error")).await;
        assert!(matches!(first, Err(CircuitBreakerError::Operation(_))));
        assert_eq!(breaker.state(), CircuitState::Closed);

        // Second failure reaches the threshold and opens the circuit.
        let second = run_through(&breaker, &calls, Err("error")).await;
        assert!(matches!(second, Err(CircuitBreakerError::Operation(_))));
        assert_eq!(breaker.state(), CircuitState::Open);

        // Third call is rejected without running the operation.
        let third = run_through(&breaker, &calls, Ok(())).await;
        assert!(matches!(third, Err(CircuitBreakerError::Open)));

        // Only the first two calls reached the operation.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_measure_time_macro() {
        let value = measure_time!("test_operation", {
            std::thread::sleep(Duration::from_millis(1));
            42
        });
        assert_eq!(value, 42);
    }
}