Skip to main content

rust_supervisor/policy/
backoff.rs

//! Backoff timing for restart scheduling.
//!
//! This module owns exponential backoff calculation and deterministic jitter
//! support. It does not sleep or spawn tasks.
5
6use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
9/// Jitter source used by backoff calculation.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum JitterMode {
12    /// Adds no jitter and returns the exponential delay unchanged.
13    Disabled,
14    /// Adds deterministic jitter derived from this seed.
15    Deterministic {
16        /// Stable seed used by tests and reproducible simulations.
17        seed: u64,
18    },
19    /// Full jitter mode with uniform random sampling between zero and upper bound.
20    FullJitter {
21        /// Stable seed used for full jitter calculation.
22        seed: u64,
23    },
24    /// Decorrelated jitter mode that depends on previous wait duration.
25    DecorrelatedJitter {
26        /// Stable seed used for decorrelated jitter calculation.
27        seed: u64,
28    },
29}
30
31/// Exponential backoff configuration for restart start_counts.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33pub struct BackoffPolicy {
34    /// Initial delay for the first restart child_start_count.
35    pub initial: Duration,
36    /// Maximum delay allowed after exponential growth and jitter.
37    pub max: Duration,
38    /// Jitter percentage in the inclusive range from zero to one hundred.
39    pub jitter_percent: u8,
40    /// Stable runtime duration after which child_start_count counters may be reset.
41    pub reset_after: Duration,
42    /// Jitter mode used by the calculation.
43    pub jitter_mode: JitterMode,
44}
45
46impl BackoffPolicy {
47    /// Creates an exponential backoff policy.
48    ///
49    /// # Arguments
50    ///
51    /// - `initial`: First restart delay.
52    /// - `max`: Maximum restart delay.
53    /// - `jitter_percent`: Jitter percentage capped at one hundred.
54    /// - `reset_after`: Runtime duration after which counters may reset.
55    ///
56    /// # Returns
57    ///
58    /// Returns a [`BackoffPolicy`] with jitter disabled.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use std::time::Duration;
64    ///
65    /// let policy = rust_supervisor::policy::backoff::BackoffPolicy::new(
66    ///     Duration::from_millis(10),
67    ///     Duration::from_millis(100),
68    ///     0,
69    ///     Duration::from_secs(1),
70    /// );
71    /// assert_eq!(policy.delay_for_child_start_count(1), Duration::from_millis(10));
72    /// ```
73    pub fn new(
74        initial: Duration,
75        max: Duration,
76        jitter_percent: u8,
77        reset_after: Duration,
78    ) -> Self {
79        Self {
80            initial,
81            max,
82            jitter_percent: jitter_percent.min(100),
83            reset_after,
84            jitter_mode: JitterMode::Disabled,
85        }
86    }
87
88    /// Returns this policy with deterministic jitter enabled.
89    ///
90    /// # Arguments
91    ///
92    /// - `seed`: Stable seed used to derive jitter.
93    ///
94    /// # Returns
95    ///
96    /// Returns a new [`BackoffPolicy`] that keeps the same timing bounds.
97    pub fn with_deterministic_jitter(mut self, seed: u64) -> Self {
98        self.jitter_mode = JitterMode::Deterministic { seed };
99        self
100    }
101
102    /// Returns this policy with full jitter enabled.
103    ///
104    /// # Arguments
105    ///
106    /// - `seed`: Stable seed used for full jitter calculation.
107    ///
108    /// # Returns
109    ///
110    /// Returns a new [`BackoffPolicy`] with full jitter mode.
111    pub fn with_full_jitter(mut self, seed: u64) -> Self {
112        self.jitter_mode = JitterMode::FullJitter { seed };
113        self
114    }
115
116    /// Returns this policy with decorrelated jitter enabled.
117    ///
118    /// # Arguments
119    ///
120    /// - `seed`: Stable seed used for decorrelated jitter calculation.
121    ///
122    /// # Returns
123    ///
124    /// Returns a new [`BackoffPolicy`] with decorrelated jitter mode.
125    pub fn with_decorrelated_jitter(mut self, seed: u64) -> Self {
126        self.jitter_mode = JitterMode::DecorrelatedJitter { seed };
127        self
128    }
129
130    /// Calculates a restart delay for a one-based child_start_count number.
131    ///
132    /// # Arguments
133    ///
134    /// - `child_start_count`: One-based restart child_start_count. Zero is treated as one.
135    ///
136    /// # Returns
137    ///
138    /// Returns a delay capped by [`BackoffPolicy::max`].
139    pub fn delay_for_child_start_count(&self, child_start_count: u64) -> Duration {
140        let exponential = self.exponential_delay(child_start_count.max(1));
141        self.apply_jitter(exponential).min(self.max)
142    }
143
144    /// Reports whether a stable runtime duration should reset counters.
145    ///
146    /// # Arguments
147    ///
148    /// - `stable_for`: Duration for which the child has run without failure.
149    ///
150    /// # Returns
151    ///
152    /// Returns `true` when `stable_for` reaches [`BackoffPolicy::reset_after`].
153    pub fn should_reset(&self, stable_for: Duration) -> bool {
154        stable_for >= self.reset_after
155    }
156
157    /// Computes the unclamped exponential delay.
158    ///
159    /// # Arguments
160    ///
161    /// - `child_start_count`: One-based restart child_start_count.
162    ///
163    /// # Returns
164    ///
165    /// Returns the exponential delay before jitter is applied.
166    fn exponential_delay(&self, child_start_count: u64) -> Duration {
167        let shift = child_start_count.saturating_sub(1).min(32);
168        let multiplier = 1_u128 << shift;
169        let millis = self.initial.as_millis().saturating_mul(multiplier);
170        duration_from_millis(millis).min(self.max)
171    }
172
173    /// Applies bounded jitter to a base delay.
174    ///
175    /// # Arguments
176    ///
177    /// - `base`: Delay before jitter.
178    ///
179    /// # Returns
180    ///
181    /// Returns a jittered delay that never exceeds the configured maximum.
182    fn apply_jitter(&self, base: Duration) -> Duration {
183        if self.jitter_percent == 0 {
184            return base;
185        }
186
187        match self.jitter_mode {
188            JitterMode::Disabled => base,
189            JitterMode::Deterministic { seed } => {
190                let jitter = deterministic_jitter(base, self.jitter_percent, seed);
191                base.saturating_add(jitter)
192            }
193            JitterMode::FullJitter { seed } => calculate_full_jitter(base, self.max, seed),
194            JitterMode::DecorrelatedJitter { seed } => {
195                calculate_decorrelated_jitter(base, self.initial, self.max, seed)
196            }
197        }
198    }
199}
200
/// Multiplier of the linear congruential step shared by every jitter mode.
const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

/// Advances the linear congruential generator by one step.
///
/// # Arguments
///
/// - `seed`: Current generator state.
///
/// # Returns
///
/// Returns the next state, computed with wrapping arithmetic.
fn lcg_step(seed: u64) -> u64 {
    seed.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1)
}

/// Draws a deterministic value in the inclusive range `0..=upper`.
///
/// # Arguments
///
/// - `seed`: Stable seed fed through one generator step.
/// - `upper`: Inclusive upper bound of the sample.
///
/// # Returns
///
/// Returns the stepped seed reduced modulo `upper + 1`.
fn sample_inclusive(seed: u64, upper: u128) -> u128 {
    u128::from(lcg_step(seed)) % upper.saturating_add(1)
}

/// Converts milliseconds into a duration without overflowing.
///
/// # Arguments
///
/// - `millis`: Millisecond count held in a wide integer.
///
/// # Returns
///
/// Returns a [`Duration`] capped at `u64::MAX` milliseconds.
fn duration_from_millis(millis: u128) -> Duration {
    Duration::from_millis(u64::try_from(millis).unwrap_or(u64::MAX))
}

/// Derives deterministic positive jitter.
///
/// # Arguments
///
/// - `base`: Base delay.
/// - `percent`: Jitter percentage.
/// - `seed`: Stable seed.
///
/// # Returns
///
/// Returns a jitter duration between zero and `percent` percent of `base`
/// (in whole milliseconds), inclusive.
fn deterministic_jitter(base: Duration, percent: u8, seed: u64) -> Duration {
    let max_jitter = base.as_millis().saturating_mul(u128::from(percent)) / 100;
    if max_jitter == 0 {
        return Duration::ZERO;
    }

    duration_from_millis(sample_inclusive(seed, max_jitter))
}

/// Calculates full jitter from a seeded sample.
///
/// The result lies between zero and `min(base, max)`, inclusive, in whole
/// milliseconds. The sample is derived from `seed` alone, so equal inputs
/// always produce equal outputs.
///
/// # Arguments
///
/// - `base`: Base exponential delay before jitter.
/// - `max`: Maximum allowed delay cap.
/// - `seed`: Stable seed for reproducible sampling.
///
/// # Returns
///
/// Returns a jittered duration between zero and the upper bound.
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use rust_supervisor::policy::backoff::calculate_full_jitter;
///
/// let delay = calculate_full_jitter(
///     Duration::from_millis(100),
///     Duration::from_millis(1000),
///     42,
/// );
/// assert!(delay <= Duration::from_millis(100));
/// ```
pub fn calculate_full_jitter(base: Duration, max: Duration, seed: u64) -> Duration {
    let upper_millis = base.min(max).as_millis();
    if upper_millis == 0 {
        return Duration::ZERO;
    }

    duration_from_millis(sample_inclusive(seed, upper_millis))
}

/// Calculates decorrelated jitter that depends on the previous wait duration.
///
/// Decorrelated jitter uses the formula: sleep = min(cap, random(base, sleep * 3)).
/// Here the sample is drawn between `initial` and `min(base * 3, max)`.
///
/// # Arguments
///
/// - `base`: Previous wait duration (the base delay for the first retry).
/// - `initial`: Minimum delay floor.
/// - `max`: Maximum delay cap.
/// - `seed`: Stable seed for reproducible sampling.
///
/// # Returns
///
/// Returns a duration between `initial` and `min(base * 3, max)`. When that
/// range is empty the floor is returned, capped by `max` so the result never
/// exceeds the cap even if `initial` is larger than `max`. Subsequent calls
/// should pass the previous result as the new `base`.
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use rust_supervisor::policy::backoff::calculate_decorrelated_jitter;
///
/// let delay = calculate_decorrelated_jitter(
///     Duration::from_millis(100),
///     Duration::from_millis(10),
///     Duration::from_millis(1000),
///     42,
/// );
/// assert!(delay >= Duration::from_millis(10));
/// assert!(delay <= Duration::from_millis(1000));
/// ```
pub fn calculate_decorrelated_jitter(
    base: Duration,
    initial: Duration,
    max: Duration,
    seed: u64,
) -> Duration {
    let cap = max.as_millis();
    let lower = initial.as_millis();
    let upper = base.as_millis().saturating_mul(3).min(cap);

    if upper <= lower {
        // Empty range: fall back to the floor, but never above the cap.
        return duration_from_millis(lower.min(cap));
    }

    // `upper <= cap` holds here, so the sampled value cannot exceed the cap.
    let span = upper - lower;
    duration_from_millis(lower + sample_inclusive(seed, span))
}
347
/// Cold start budget tracker for limiting restarts during initial startup.
///
/// Counts restart attempts inside a fixed window that begins at
/// `start_time_secs` and lasts `window_secs` seconds (inclusive). The start
/// time is never advanced by this type, so once the window has passed the
/// tracker stays in its "window expired" state.
#[derive(Debug, Clone)]
pub struct ColdStartBudget {
    /// Time window in seconds during which cold start budget applies.
    pub window_secs: u64,
    /// Maximum number of restarts allowed within the cold start window.
    pub max_restarts: u32,
    /// Current restart count within the window.
    pub restart_count: u32,
    /// Supervisor or child start time (Unix epoch seconds).
    pub start_time_secs: u64,
}

impl ColdStartBudget {
    /// Creates a new cold start budget tracker.
    ///
    /// # Arguments
    ///
    /// - `window_secs`: Time window in seconds for cold start period.
    /// - `max_restarts`: Maximum restarts allowed within the window.
    /// - `start_time_secs`: Start time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns a new [`ColdStartBudget`] with zero restart count.
    ///
    /// # Examples
    ///
    /// ```
    /// use rust_supervisor::policy::backoff::ColdStartBudget;
    ///
    /// let budget = ColdStartBudget::new(300, 5, 1000);
    /// assert_eq!(budget.get_restart_count(), 0);
    /// assert!(!budget.is_exhausted(1000));
    /// ```
    pub fn new(window_secs: u64, max_restarts: u32, start_time_secs: u64) -> Self {
        Self {
            window_secs,
            max_restarts,
            restart_count: 0,
            start_time_secs,
        }
    }

    /// Records a restart attempt and checks if budget is exhausted.
    ///
    /// Outside the window the counter is set to one and `false` is returned.
    /// Inside the window the counter is incremented (saturating at
    /// `u32::MAX`, so it can neither panic nor wrap back to zero).
    ///
    /// # Arguments
    ///
    /// - `current_time_secs`: Current time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns `true` when the updated count is greater than `max_restarts`
    /// within the active window, `false` otherwise.
    ///
    /// # Examples
    ///
    /// ```
    /// use rust_supervisor::policy::backoff::ColdStartBudget;
    ///
    /// let mut budget = ColdStartBudget::new(300, 2, 1000);
    /// assert!(!budget.record_restart(1010)); // First restart
    /// assert!(!budget.record_restart(1020)); // Second restart
    /// assert!(budget.record_restart(1030));  // Third restart exhausts budget
    /// ```
    pub fn record_restart(&mut self, current_time_secs: u64) -> bool {
        if !self.is_window_active(current_time_secs) {
            // Window expired: the budget no longer applies; count this restart only.
            self.restart_count = 1;
            return false;
        }

        self.restart_count = self.restart_count.saturating_add(1);
        self.restart_count > self.max_restarts
    }

    /// Checks if the cold start budget is currently used up.
    ///
    /// Note that this compares with `>=`, whereas
    /// [`ColdStartBudget::record_restart`] reports exhaustion only once the
    /// count is strictly greater than `max_restarts`.
    ///
    /// # Arguments
    ///
    /// - `current_time_secs`: Current time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns `true` when the window is active and the recorded restart
    /// count has reached `max_restarts`; `false` once the window has expired.
    pub fn is_exhausted(&self, current_time_secs: u64) -> bool {
        self.is_window_active(current_time_secs) && self.restart_count >= self.max_restarts
    }

    /// Returns the current restart count within the cold start window.
    ///
    /// # Returns
    ///
    /// Returns the number of restarts recorded in the current window.
    pub fn get_restart_count(&self) -> u32 {
        self.restart_count
    }

    /// Checks if the cold start window is still active.
    ///
    /// # Arguments
    ///
    /// - `current_time_secs`: Current time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns `true` while at most `window_secs` seconds have elapsed since
    /// `start_time_secs`. Times earlier than the start count as zero elapsed.
    pub fn is_window_active(&self, current_time_secs: u64) -> bool {
        current_time_secs.saturating_sub(self.start_time_secs) <= self.window_secs
    }
}
469
/// Hot loop detector for identifying rapid crash-restart cycles.
///
/// Keeps the timestamps of recent crashes and reports a hot loop when at
/// least `min_restarts` of them fall inside the sliding window that ends at
/// the queried time.
#[derive(Debug, Clone)]
pub struct HotLoopDetector {
    /// Sliding time window in seconds for detecting hot loops.
    pub window_secs: u64,
    /// Minimum number of restarts within window to trigger detection.
    pub min_restarts: u32,
    /// Timestamps of recent crashes (Unix epoch seconds).
    pub crash_times: Vec<u64>,
}

impl HotLoopDetector {
    /// Creates a new hot loop detector.
    ///
    /// # Arguments
    ///
    /// - `window_secs`: Sliding time window in seconds.
    /// - `min_restarts`: Minimum restarts within window to trigger detection.
    ///
    /// # Returns
    ///
    /// Returns a new [`HotLoopDetector`] with empty crash history.
    ///
    /// # Examples
    ///
    /// ```
    /// use rust_supervisor::policy::backoff::HotLoopDetector;
    ///
    /// let detector = HotLoopDetector::new(60, 5);
    /// assert!(!detector.is_hot_loop_detected(1000));
    /// ```
    pub fn new(window_secs: u64, min_restarts: u32) -> Self {
        Self {
            window_secs,
            min_restarts,
            crash_times: Vec::new(),
        }
    }

    /// Records a crash event and checks if hot loop is detected.
    ///
    /// The new timestamp is stored, then every stored timestamp that falls
    /// outside the window ending at `crash_time_secs` is discarded.
    ///
    /// # Arguments
    ///
    /// - `crash_time_secs`: Crash timestamp as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns `true` if hot loop condition is detected, `false` otherwise.
    ///
    /// # Examples
    ///
    /// ```
    /// use rust_supervisor::policy::backoff::HotLoopDetector;
    ///
    /// let mut detector = HotLoopDetector::new(60, 3);
    /// detector.record_crash(1000);
    /// detector.record_crash(1010);
    /// detector.record_crash(1020);
    /// assert!(detector.is_hot_loop_detected(1020)); // 3 crashes in 20 seconds
    /// ```
    pub fn record_crash(&mut self, crash_time_secs: u64) -> bool {
        self.crash_times.push(crash_time_secs);

        let window_secs = self.window_secs;
        self.crash_times
            .retain(|&t| Self::is_within_window(t, crash_time_secs, window_secs));

        self.is_hot_loop_detected(crash_time_secs)
    }

    /// Checks if hot loop condition is currently detected.
    ///
    /// # Arguments
    ///
    /// - `current_time_secs`: Current time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns `true` if crash count within window meets or exceeds threshold.
    pub fn is_hot_loop_detected(&self, current_time_secs: u64) -> bool {
        self.get_crash_count_in_window(current_time_secs) >= self.min_restarts as usize
    }

    /// Returns the number of crashes within the current sliding window.
    ///
    /// # Arguments
    ///
    /// - `current_time_secs`: Current time as Unix epoch seconds.
    ///
    /// # Returns
    ///
    /// Returns the count of stored crashes within the window ending at
    /// `current_time_secs`.
    pub fn get_crash_count_in_window(&self, current_time_secs: u64) -> usize {
        self.crash_times
            .iter()
            .filter(|&&t| Self::is_within_window(t, current_time_secs, self.window_secs))
            .count()
    }

    /// Clears the crash history, typically called after successful stable runtime.
    pub fn clear_history(&mut self) {
        self.crash_times.clear();
    }

    /// Reports whether a crash timestamp lies inside the sliding window.
    ///
    /// A crash counts when it is strictly newer than `now_secs - window_secs`.
    /// When the window reaches back before time zero there is no cutoff, so
    /// every timestamp counts; a saturating subtraction would instead clamp
    /// the cutoff to zero and wrongly exclude a crash recorded at time zero.
    fn is_within_window(crash_time_secs: u64, now_secs: u64, window_secs: u64) -> bool {
        match now_secs.checked_sub(window_secs) {
            Some(cutoff) => crash_time_secs > cutoff,
            None => true,
        }
    }
}
580
#[cfg(test)]
mod backoff_extended_tests {
    use crate::policy::backoff::{
        ColdStartBudget, HotLoopDetector, calculate_decorrelated_jitter, calculate_full_jitter,
    };
    use std::time::Duration;

    /// A budget of three tolerates three restarts in the window; the fourth is flagged.
    #[test]
    fn test_cold_start_budget_basic_tracking() {
        let mut tracker = ColdStartBudget::new(300, 3, 1000);

        // Restarts one to three stay within the limit.
        for now in [1010_u64, 1020, 1030] {
            assert!(!tracker.record_restart(now));
        }

        // The fourth restart goes over the limit.
        assert!(tracker.record_restart(1040));
    }

    /// Once the window has expired, a restart resets the counter to one.
    #[test]
    fn test_cold_start_window_expiry() {
        let mut tracker = ColdStartBudget::new(300, 2, 1000);

        // Use up the budget while the window is active.
        tracker.record_restart(1010);
        tracker.record_restart(1020);

        // 1400 lies beyond the 300 second window that started at 1000.
        let exhausted = tracker.record_restart(1400);
        assert!(!exhausted);
        assert_eq!(tracker.get_restart_count(), 1);
    }

    /// Detection fires once the crash count in the window reaches the threshold.
    #[test]
    fn test_hot_loop_detection_basic() {
        let mut loop_detector = HotLoopDetector::new(60, 3);

        loop_detector.record_crash(1000);
        loop_detector.record_crash(1010);
        // Two crashes are below the threshold of three.
        assert!(!loop_detector.is_hot_loop_detected(1010));

        loop_detector.record_crash(1020);
        // The third crash inside the window trips detection.
        assert!(loop_detector.is_hot_loop_detected(1020));
    }

    /// Crashes that slide out of the window stop counting towards detection.
    #[test]
    fn test_hot_loop_sliding_window() {
        let mut loop_detector = HotLoopDetector::new(60, 3);

        for crash_time in [1000_u64, 1010, 1020] {
            loop_detector.record_crash(crash_time);
        }
        assert!(loop_detector.is_hot_loop_detected(1020));

        // At 1070 only the crash at 1020 is still inside the last 60 seconds.
        assert!(!loop_detector.is_hot_loop_detected(1070));
    }

    /// Full jitter never exceeds the base delay when the base is below the cap.
    #[test]
    fn test_full_jitter_bounds() {
        let base = Duration::from_millis(100);
        let cap = Duration::from_millis(1000);

        let jittered = calculate_full_jitter(base, cap, 42);
        assert!(jittered <= Duration::from_millis(100));
    }

    /// Decorrelated jitter stays between the initial floor and the maximum cap.
    #[test]
    fn test_decorrelated_jitter_bounds() {
        let base = Duration::from_millis(100);
        let floor = Duration::from_millis(10);
        let cap = Duration::from_millis(1000);

        let jittered = calculate_decorrelated_jitter(base, floor, cap, 42);
        assert!(jittered >= Duration::from_millis(10));
        assert!(jittered <= Duration::from_millis(1000));
    }
}
663}