Skip to main content

stygian_graph/adapters/
resilience.rs

1//! Resilience adapters
2
3use crate::domain::error::Result;
4use crate::ports::{CircuitBreaker, CircuitState, RateLimitConfig, RateLimiter};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11/// Circuit breaker implementation with configurable thresholds
12///
13/// Implements the circuit breaker pattern to prevent cascading failures.
14/// Tracks failure rate and automatically opens the circuit when threshold is exceeded.
15///
16/// # State Machine
17///
18/// - **Closed**: Normal operation, all requests pass through
19/// - **Open**: Too many failures, all requests fail fast
20/// - **`HalfOpen`**: Testing recovery, limited requests allowed
21///
22/// # Example
23///
24/// ```
25/// use stygian_graph::adapters::resilience::CircuitBreakerImpl;
26/// use stygian_graph::ports::{CircuitBreaker, CircuitState};
27///
28/// let cb = CircuitBreakerImpl::new(5, std::time::Duration::from_secs(30));
29/// // Record some failures
30/// cb.record_failure();
31/// cb.record_failure();
32/// // Check state
33/// assert!(matches!(cb.state(), CircuitState::Closed | CircuitState::Open));
34/// ```
35pub struct CircuitBreakerImpl {
36    state: Arc<RwLock<CircuitBreakerState>>,
37    failure_threshold: u32,
38    timeout: Duration,
39}
40
41#[derive(Debug)]
42struct CircuitBreakerState {
43    current: CircuitState,
44    failure_count: u32,
45    last_failure_time: Option<Instant>,
46}
47
48impl CircuitBreakerImpl {
49    /// Create a new circuit breaker
50    ///
51    /// # Arguments
52    ///
53    /// * `failure_threshold` - Number of failures before opening circuit
54    /// * `timeout` - Duration to wait before attempting reset
55    #[must_use]
56    pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
57        Self {
58            state: Arc::new(RwLock::new(CircuitBreakerState {
59                current: CircuitState::Closed,
60                failure_count: 0,
61                last_failure_time: None,
62            })),
63            failure_threshold,
64            timeout,
65        }
66    }
67
68    /// Check if timeout has elapsed and circuit can transition to `HalfOpen`
69    fn should_attempt_reset(&self, state: &CircuitBreakerState) -> bool {
70        if state.current != CircuitState::Open {
71            return false;
72        }
73
74        state
75            .last_failure_time
76            .is_some_and(|last_failure| last_failure.elapsed() >= self.timeout)
77    }
78}
79
80impl CircuitBreaker for CircuitBreakerImpl {
81    fn state(&self) -> CircuitState {
82        let state = self.state.read();
83        state.current
84    }
85
86    fn record_success(&self) {
87        let mut state = self.state.write();
88        // Success resets failures and closes circuit
89        state.failure_count = 0;
90        state.current = CircuitState::Closed;
91        state.last_failure_time = None;
92    }
93
94    fn record_failure(&self) {
95        let mut state = self.state.write();
96        state.failure_count += 1;
97        state.last_failure_time = Some(Instant::now());
98
99        // Open circuit if threshold exceeded
100        if state.failure_count >= self.failure_threshold {
101            state.current = CircuitState::Open;
102        }
103    }
104
105    fn attempt_reset(&self) -> bool {
106        let mut state = self.state.write();
107
108        if self.should_attempt_reset(&state) {
109            state.current = CircuitState::HalfOpen;
110            state.failure_count = 0;
111            true
112        } else {
113            false
114        }
115    }
116}
117
118/// No-op circuit breaker for testing
119///
120/// Always reports Closed state and ignores all state transitions.
121/// Useful for testing scenarios where circuit breaker behavior should be disabled.
122///
123/// # Example
124///
125/// ```
126/// use stygian_graph::adapters::resilience::NoopCircuitBreaker;
127/// use stygian_graph::ports::{CircuitBreaker, CircuitState};
128///
129/// let cb = NoopCircuitBreaker;
130/// cb.record_failure();
131/// assert_eq!(cb.state(), CircuitState::Closed);
132/// ```
133#[derive(Debug, Default, Clone, Copy)]
134pub struct NoopCircuitBreaker;
135
136impl CircuitBreaker for NoopCircuitBreaker {
137    fn state(&self) -> CircuitState {
138        CircuitState::Closed
139    }
140
141    fn record_success(&self) {
142        // No-op
143    }
144
145    fn record_failure(&self) {
146        // No-op
147    }
148
149    fn attempt_reset(&self) -> bool {
150        false
151    }
152}
153
154/// Token bucket rate limiter implementation
155///
156/// Implements rate limiting using the token bucket algorithm.
157/// Supports per-key rate limiting for multi-tenant scenarios.
158///
159/// # Algorithm
160///
161/// - Each key has a bucket with a maximum number of tokens
162/// - Tokens are consumed on each request
163/// - Tokens regenerate over time based on the configured window
164/// - Requests are rejected when bucket is empty
165///
166/// # Example
167///
168/// ```
169/// use stygian_graph::adapters::resilience::TokenBucketRateLimiter;
170/// use stygian_graph::ports::{RateLimiter, RateLimitConfig};
171/// use std::time::Duration;
172///
173/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
174/// let config = RateLimitConfig {
175///     max_requests: 10,
176///     window: Duration::from_secs(60),
177/// };
178/// let limiter = TokenBucketRateLimiter::new(config);
179/// assert!(limiter.check_rate_limit("api:test").await.unwrap());
180/// # });
181/// ```
182pub struct TokenBucketRateLimiter {
183    config: RateLimitConfig,
184    buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
185}
186
187#[derive(Debug)]
188struct TokenBucket {
189    tokens: u32,
190    last_refill: Instant,
191}
192
193impl TokenBucketRateLimiter {
194    /// Create a new token bucket rate limiter
195    #[must_use]
196    pub fn new(config: RateLimitConfig) -> Self {
197        Self {
198            config,
199            buckets: Arc::new(RwLock::new(HashMap::new())),
200        }
201    }
202
203    /// Refill tokens based on elapsed time
204    fn refill_tokens(&self, bucket: &mut TokenBucket) {
205        let elapsed = bucket.last_refill.elapsed();
206        let refill_rate = f64::from(self.config.max_requests) / self.config.window.as_secs_f64();
207        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
208        let tokens_to_add = (elapsed.as_secs_f64() * refill_rate) as u32;
209
210        if tokens_to_add > 0 {
211            bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_requests);
212            bucket.last_refill = Instant::now();
213        }
214    }
215}
216
217#[async_trait]
218impl RateLimiter for TokenBucketRateLimiter {
219    #[allow(clippy::significant_drop_tightening)]
220    async fn check_rate_limit(&self, key: &str) -> Result<bool> {
221        let has_tokens = {
222            let mut buckets = self.buckets.write();
223            let bucket = buckets
224                .entry(key.to_string())
225                .or_insert_with(|| TokenBucket {
226                    tokens: self.config.max_requests,
227                    last_refill: Instant::now(),
228                });
229            self.refill_tokens(bucket);
230            bucket.tokens > 0
231        };
232        Ok(has_tokens)
233    }
234
235    async fn record_request(&self, key: &str) -> Result<()> {
236        {
237            let mut buckets = self.buckets.write();
238            if let Some(bucket) = buckets.get_mut(key)
239                && bucket.tokens > 0
240            {
241                bucket.tokens -= 1;
242            }
243        }
244        Ok(())
245    }
246}
247
248/// No-op rate limiter for testing
249///
250/// Always allows requests and ignores all rate limit tracking.
251/// Useful for testing scenarios where rate limiting should be disabled.
252///
253/// # Example
254///
255/// ```
256/// use stygian_graph::adapters::resilience::NoopRateLimiter;
257/// use stygian_graph::ports::RateLimiter;
258///
259/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
260/// let limiter = NoopRateLimiter;
261/// assert!(limiter.check_rate_limit("any_key").await.unwrap());
262/// # });
263/// ```
264#[derive(Debug, Default, Clone, Copy)]
265pub struct NoopRateLimiter;
266
267#[async_trait]
268impl RateLimiter for NoopRateLimiter {
269    async fn check_rate_limit(&self, _key: &str) -> Result<bool> {
270        Ok(true)
271    }
272
273    async fn record_request(&self, _key: &str) -> Result<()> {
274        Ok(())
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// A success must close a circuit that was actually open.
    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        cb.record_failure();
        cb.record_failure();
        // Below the threshold the circuit is still closed.
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_on_threshold() {
        let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    /// With a zero timeout an open circuit can be reset immediately.
    #[test]
    fn test_circuit_breaker_half_opens_after_timeout() {
        let cb = CircuitBreakerImpl::new(1, Duration::ZERO);
        // A closed circuit has nothing to reset.
        assert!(!cb.attempt_reset());
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.attempt_reset());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    /// Before the timeout has elapsed the circuit must stay open.
    #[test]
    fn test_circuit_breaker_stays_open_before_timeout() {
        let cb = CircuitBreakerImpl::new(1, Duration::from_mins(60));
        cb.record_failure();
        assert!(!cb.attempt_reset());
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_noop_circuit_breaker_always_closed() {
        let cb = NoopCircuitBreaker;
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_within_limit() -> Result<()> {
        let config = RateLimitConfig {
            max_requests: 10,
            window: Duration::from_mins(1),
        };
        let limiter = TokenBucketRateLimiter::new(config);

        assert!(limiter.check_rate_limit("test").await?);
        limiter.record_request("test").await?;
        assert!(limiter.check_rate_limit("test").await?);
        Ok(())
    }

    /// Once every token is spent the key is rejected; other keys are unaffected.
    #[tokio::test]
    async fn test_rate_limiter_blocks_when_exhausted() -> Result<()> {
        // The window is long enough that no token refills during the test.
        let config = RateLimitConfig {
            max_requests: 2,
            window: Duration::from_mins(60),
        };
        let limiter = TokenBucketRateLimiter::new(config);

        assert!(limiter.check_rate_limit("busy").await?);
        limiter.record_request("busy").await?;
        limiter.record_request("busy").await?;
        assert!(!limiter.check_rate_limit("busy").await?);
        assert!(limiter.check_rate_limit("idle").await?);
        Ok(())
    }

    #[tokio::test]
    async fn test_noop_rate_limiter_always_allows() -> Result<()> {
        let limiter = NoopRateLimiter;
        assert!(limiter.check_rate_limit("any").await?);
        limiter.record_request("any").await?;
        assert!(limiter.check_rate_limit("any").await?);
        Ok(())
    }
}
334
335// ─── Exponential Backoff Retry ────────────────────────────────────────────────
336
/// Policy controlling exponential backoff retry behaviour.
///
/// Delays follow the formula: `base_delay * 2^attempt + rand(0..jitter_ms)`.
/// The computed delay is capped at `max_delay`.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::resilience::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(10));
/// ```
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (not counting the initial call)
    pub max_attempts: u32,
    /// Base delay for the first retry
    pub base_delay: Duration,
    /// Maximum delay cap
    pub max_delay: Duration,
    /// Additional random jitter ceiling (milliseconds)
    pub jitter_ms: u64,
}

impl RetryPolicy {
    /// Create a new retry policy.
    ///
    /// The jitter ceiling defaults to 50 ms; change it with
    /// [`with_jitter_ms`](Self::with_jitter_ms).
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(5, Duration::from_millis(200), Duration::from_secs(30));
    /// assert_eq!(p.max_attempts, 5);
    /// ```
    #[must_use]
    pub const fn new(max_attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_attempts,
            base_delay,
            max_delay,
            jitter_ms: 50,
        }
    }

    /// Override the jitter ceiling in milliseconds.
    ///
    /// A ceiling of `0` disables jitter, making delays deterministic.
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(5))
    ///     .with_jitter_ms(100);
    /// assert_eq!(p.jitter_ms, 100);
    /// ```
    #[must_use]
    pub const fn with_jitter_ms(mut self, jitter_ms: u64) -> Self {
        self.jitter_ms = jitter_ms;
        self
    }

    /// Compute the sleep duration for a given attempt index (0-based).
    ///
    /// The back-off is computed at nanosecond resolution, so sub-millisecond
    /// base delays are honoured, and all arithmetic saturates, so large
    /// attempt indices or durations simply yield `max_delay`.
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(10))
    ///     .with_jitter_ms(0);
    /// // attempt 0 → 100 ms, attempt 1 → 200 ms, attempt 2 → 400 ms
    /// assert_eq!(p.delay_for(0), Duration::from_millis(100));
    /// assert_eq!(p.delay_for(1), Duration::from_millis(200));
    /// ```
    #[must_use]
    pub fn delay_for(&self, attempt: u32) -> Duration {
        const NANOS_PER_MILLI: u128 = 1_000_000;
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        // 2^attempt, saturating once the shift no longer fits.
        let factor = 1u128.checked_shl(attempt).unwrap_or(u128::MAX);
        let backoff_nanos = self.base_delay.as_nanos().saturating_mul(factor);
        // u64 milliseconds times 10^6 always fits in u128.
        let jitter_nanos = u128::from(self.jitter_millis()) * NANOS_PER_MILLI;
        let total_nanos = backoff_nanos.saturating_add(jitter_nanos);

        if total_nanos >= self.max_delay.as_nanos() {
            return self.max_delay;
        }

        // `total_nanos` is below `max_delay`, itself a valid `Duration`, so the
        // whole seconds fit in `u64` and the remainder is below 10^9.
        #[allow(clippy::cast_possible_truncation)]
        let (secs, subsec_nanos) = (
            (total_nanos / NANOS_PER_SEC) as u64,
            (total_nanos % NANOS_PER_SEC) as u32,
        );
        Duration::new(secs, subsec_nanos)
    }

    /// Pick a pseudo-random jitter in `0..jitter_ms` milliseconds (`0` when
    /// jitter is disabled).
    fn jitter_millis(&self) -> u64 {
        if self.jitter_ms == 0 {
            return 0;
        }
        // Good enough without pulling in `rand`: scramble the sub-second
        // nanoseconds of the wall clock with one LCG step and keep the
        // better-mixed high bits.
        let seed = u64::from(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .subsec_nanos(),
        );
        (seed
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407)
            >> 33)
            % self.jitter_ms
    }
}

impl Default for RetryPolicy {
    /// Three retries, 200 ms base delay, 30 s cap and the default 50 ms jitter.
    fn default() -> Self {
        Self::new(3, Duration::from_millis(200), Duration::from_secs(30))
    }
}
449
450/// Execute an async operation with automatic retry according to a [`RetryPolicy`].
451///
452/// Returns the first `Ok` value, or the last `Err` after all attempts are exhausted.
453/// Each retry sleeps for an exponentially increasing delay with jitter.
454///
455/// # Example
456///
457/// ```
458/// use stygian_graph::adapters::resilience::{RetryPolicy, retry};
459/// use std::sync::atomic::{AtomicU32, Ordering};
460/// use std::sync::Arc;
461/// use std::time::Duration;
462///
463/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
464/// let attempts = Arc::new(AtomicU32::new(0));
465/// let policy = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10))
466///     .with_jitter_ms(0);
467///
468/// let result = retry(&policy, || {
469///     let counter = Arc::clone(&attempts);
470///     async move {
471///         let n = counter.fetch_add(1, Ordering::SeqCst);
472///         if n < 2 { Err("not yet".to_string()) } else { Ok(n) }
473///     }
474/// }).await;
475///
476/// assert!(result.is_ok());
477/// assert_eq!(attempts.load(Ordering::SeqCst), 3);
478/// # });
479/// ```
480///
481/// # Errors
482///
483/// Returns whatever the supplied `f` future returns on the final attempt. The
484/// caller-supplied error type `E` is opaque to the helper; only `f` decides
485/// when a result is "final".
486pub async fn retry<F, Fut, T, E>(policy: &RetryPolicy, mut f: F) -> std::result::Result<T, E>
487where
488    F: FnMut() -> Fut,
489    Fut: std::future::Future<Output = std::result::Result<T, E>>,
490{
491    let mut result = f().await;
492    for attempt in 1..=policy.max_attempts {
493        if result.is_ok() {
494            return result;
495        }
496        tokio::time::sleep(policy.delay_for(attempt - 1)).await;
497        result = f().await;
498    }
499    result
500}
501
#[cfg(test)]
mod retry_tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Build a jitter-free policy so that delays are deterministic.
    fn fixed_policy(max_attempts: u32, base: Duration, cap: Duration) -> RetryPolicy {
        RetryPolicy::new(max_attempts, base, cap).with_jitter_ms(0)
    }

    /// Each attempt doubles the previous delay.
    #[test]
    fn delay_for_doubles() {
        let policy = fixed_policy(4, Duration::from_millis(100), Duration::from_mins(1));
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400), (3, 800)] {
            assert_eq!(policy.delay_for(attempt), Duration::from_millis(expected_ms));
        }
    }

    /// A delay beyond `max_delay` is clamped to it.
    #[test]
    fn delay_capped_at_max() {
        let cap = Duration::from_secs(3);
        let policy = fixed_policy(10, Duration::from_secs(1), cap);
        // 1000 * 2^4 = 16_000 ms, capped at 3_000 ms
        assert_eq!(policy.delay_for(4), cap);
    }

    /// An operation that succeeds immediately yields its value.
    #[tokio::test]
    async fn retry_succeeds_on_first_try() {
        let policy = fixed_policy(3, Duration::from_millis(1), Duration::from_millis(10));
        let outcome = retry(&policy, || async { Ok::<i32, &str>(42) }).await;
        assert_eq!(outcome, Ok(42));
    }

    /// Failures are retried until the operation finally succeeds.
    #[tokio::test]
    async fn retry_retries_until_success() {
        let calls = AtomicU32::new(0);
        let policy = fixed_policy(4, Duration::from_millis(1), Duration::from_millis(50));

        let outcome: std::result::Result<u32, String> = retry(&policy, || {
            let calls = &calls;
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) {
                    n if n < 3 => Err(format!("fail {n}")),
                    n => Ok(n),
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 4); // 3 failures + 1 success
    }

    /// When every attempt fails, the error is returned after the last retry.
    #[tokio::test]
    async fn retry_exhausts_and_returns_last_error() {
        let calls = AtomicU32::new(0);
        let policy = fixed_policy(2, Duration::from_millis(1), Duration::from_millis(10));

        let outcome: std::result::Result<(), String> = retry(&policy, || {
            let calls = &calls;
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("always fails".to_string())
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3); // initial + 2 retries
    }
}