Skip to main content

tower_resilience_adaptive/
algorithm.rs

//! Adaptive concurrency control algorithms.
//!
//! This module provides algorithms that dynamically adjust a concurrency
//! limit based on observed request latency and request failures.
5
6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::time::Duration;
8use tower_resilience_core::aimd::{AimdConfig, AimdController};
9
/// A strategy that adapts a concurrency limit from request outcomes.
///
/// The owner reports how each request ended (success with its latency,
/// failure, or drop) and queries the limit the algorithm currently
/// recommends, together with the bounds that limit is kept within.
/// The `Send + Sync` supertraits allow one instance to be shared across
/// threads.
pub trait ConcurrencyAlgorithm: Send + Sync {
    /// Reports a request that finished successfully after `latency`.
    fn record_success(&self, latency: Duration);

    /// Reports a request that finished with an error.
    fn record_failure(&self);

    /// Reports a request that was dropped or cancelled before finishing.
    fn record_dropped(&self);

    /// Returns the concurrency limit currently in effect.
    fn limit(&self) -> usize;

    /// Returns the configured lower bound for the limit.
    fn min_limit(&self) -> usize;

    /// Returns the configured upper bound for the limit.
    fn max_limit(&self) -> usize;
}
30
31/// AIMD (Additive Increase Multiplicative Decrease) algorithm.
32///
33/// This is the classic TCP congestion control algorithm:
34/// - On success: increase limit by a fixed amount
35/// - On failure/timeout: decrease limit by a factor
36///
37/// The algorithm creates a "sawtooth" pattern as it probes for capacity.
38pub struct Aimd {
39    controller: AimdController,
40    /// Latency threshold above which we consider the system congested.
41    latency_threshold: Duration,
42}
43
44impl Aimd {
45    /// Create a new AIMD algorithm with the given configuration.
46    pub fn new(config: AimdConfig, latency_threshold: Duration) -> Self {
47        Self {
48            controller: AimdController::new(config),
49            latency_threshold,
50        }
51    }
52
53    /// Create a builder for configuring AIMD.
54    pub fn builder() -> AimdBuilder {
55        AimdBuilder::default()
56    }
57}
58
59impl ConcurrencyAlgorithm for Aimd {
60    fn record_success(&self, latency: Duration) {
61        if latency > self.latency_threshold {
62            // High latency indicates congestion
63            self.controller.record_failure();
64        } else {
65            self.controller.record_success();
66        }
67    }
68
69    fn record_failure(&self) {
70        self.controller.record_failure();
71    }
72
73    fn record_dropped(&self) {
74        // Dropped requests don't affect the limit
75    }
76
77    fn limit(&self) -> usize {
78        self.controller.limit()
79    }
80
81    fn min_limit(&self) -> usize {
82        self.controller.min_limit()
83    }
84
85    fn max_limit(&self) -> usize {
86        self.controller.max_limit()
87    }
88}
89
/// Builder for the [`Aimd`] algorithm.
///
/// Defaults:
/// - initial limit 10, with bounds `1..=100`;
/// - increase step 1 and decrease factor 0.5;
/// - latency threshold 100 ms.
#[derive(Debug, Clone)]
pub struct AimdBuilder {
    /// Limit the algorithm starts from.
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Additive step applied on success.
    increase_by: usize,
    /// Multiplicative factor applied on congestion.
    decrease_factor: f64,
    /// Successes slower than this count as congestion.
    latency_threshold: Duration,
}

impl Default for AimdBuilder {
    fn default() -> Self {
        let latency_threshold = Duration::from_millis(100);
        Self {
            latency_threshold,
            decrease_factor: 0.5,
            increase_by: 1,
            max_limit: 100,
            min_limit: 1,
            initial_limit: 10,
        }
    }
}
113
114impl AimdBuilder {
115    /// Set the initial concurrency limit.
116    pub fn initial_limit(mut self, limit: usize) -> Self {
117        self.initial_limit = limit;
118        self
119    }
120
121    /// Set the minimum concurrency limit.
122    pub fn min_limit(mut self, limit: usize) -> Self {
123        self.min_limit = limit;
124        self
125    }
126
127    /// Set the maximum concurrency limit.
128    pub fn max_limit(mut self, limit: usize) -> Self {
129        self.max_limit = limit;
130        self
131    }
132
133    /// Set the additive increase amount per success.
134    pub fn increase_by(mut self, amount: usize) -> Self {
135        self.increase_by = amount;
136        self
137    }
138
139    /// Set the multiplicative decrease factor on failure.
140    pub fn decrease_factor(mut self, factor: f64) -> Self {
141        self.decrease_factor = factor;
142        self
143    }
144
145    /// Set the latency threshold for congestion detection.
146    ///
147    /// Requests taking longer than this are considered a congestion signal.
148    pub fn latency_threshold(mut self, threshold: Duration) -> Self {
149        self.latency_threshold = threshold;
150        self
151    }
152
153    /// Build the AIMD algorithm.
154    pub fn build(self) -> Aimd {
155        let config = AimdConfig::new()
156            .with_initial_limit(self.initial_limit)
157            .with_min_limit(self.min_limit)
158            .with_max_limit(self.max_limit)
159            .with_increase_by(self.increase_by)
160            .with_decrease_factor(self.decrease_factor);
161
162        Aimd::new(config, self.latency_threshold)
163    }
164}
165
166/// TCP Vegas algorithm for concurrency control.
167///
168/// Vegas uses RTT (round-trip time) measurements to detect congestion
169/// before it causes packet loss. It estimates the queue depth and adjusts
170/// the concurrency limit to maintain a target queue size.
171///
172/// This is more stable than AIMD and avoids the sawtooth pattern.
173pub struct Vegas {
174    /// Current limit
175    limit: AtomicUsize,
176    /// Minimum limit
177    min_limit: usize,
178    /// Maximum limit
179    max_limit: usize,
180    /// Minimum observed RTT (used as baseline)
181    min_rtt_nanos: AtomicU64,
182    /// Alpha threshold - if queue estimate < alpha, increase
183    alpha: usize,
184    /// Beta threshold - if queue estimate > beta, decrease
185    beta: usize,
186    /// Smoothing factor for RTT measurements
187    smoothing: f64,
188    /// Smoothed RTT in nanoseconds
189    smoothed_rtt_nanos: AtomicU64,
190    /// Number of samples collected
191    sample_count: AtomicUsize,
192    /// Minimum samples before adjusting
193    min_samples: usize,
194}
195
196impl Vegas {
197    /// Create a new Vegas algorithm.
198    pub fn new(
199        initial_limit: usize,
200        min_limit: usize,
201        max_limit: usize,
202        alpha: usize,
203        beta: usize,
204    ) -> Self {
205        Self {
206            limit: AtomicUsize::new(initial_limit.clamp(min_limit, max_limit)),
207            min_limit,
208            max_limit,
209            min_rtt_nanos: AtomicU64::new(u64::MAX),
210            alpha,
211            beta,
212            smoothing: 0.5,
213            smoothed_rtt_nanos: AtomicU64::new(0),
214            sample_count: AtomicUsize::new(0),
215            min_samples: 10,
216        }
217    }
218
219    /// Create a builder for Vegas.
220    pub fn builder() -> VegasBuilder {
221        VegasBuilder::default()
222    }
223
224    fn update_rtt(&self, rtt: Duration) {
225        let rtt_nanos = rtt.as_nanos() as u64;
226
227        // Update minimum RTT
228        let mut current_min = self.min_rtt_nanos.load(Ordering::Relaxed);
229        while rtt_nanos < current_min {
230            match self.min_rtt_nanos.compare_exchange_weak(
231                current_min,
232                rtt_nanos,
233                Ordering::Relaxed,
234                Ordering::Relaxed,
235            ) {
236                Ok(_) => break,
237                Err(c) => current_min = c,
238            }
239        }
240
241        // Update smoothed RTT using exponential moving average
242        let current_smoothed = self.smoothed_rtt_nanos.load(Ordering::Relaxed);
243        let new_smoothed = if current_smoothed == 0 {
244            rtt_nanos
245        } else {
246            (self.smoothing * rtt_nanos as f64 + (1.0 - self.smoothing) * current_smoothed as f64)
247                as u64
248        };
249        self.smoothed_rtt_nanos
250            .store(new_smoothed, Ordering::Relaxed);
251
252        self.sample_count.fetch_add(1, Ordering::Relaxed);
253    }
254
255    fn adjust_limit(&self) {
256        // Don't adjust until we have enough samples
257        if self.sample_count.load(Ordering::Relaxed) < self.min_samples {
258            return;
259        }
260
261        let min_rtt = self.min_rtt_nanos.load(Ordering::Relaxed);
262        let smoothed_rtt = self.smoothed_rtt_nanos.load(Ordering::Relaxed);
263
264        if min_rtt == u64::MAX || min_rtt == 0 || smoothed_rtt == 0 {
265            return;
266        }
267
268        let current_limit = self.limit.load(Ordering::Relaxed);
269
270        // Estimate queue depth: (smoothed_rtt - min_rtt) / min_rtt * current_limit
271        // This estimates how many requests are "queued" beyond the minimum RTT
272        let queue_estimate = if smoothed_rtt > min_rtt {
273            ((smoothed_rtt - min_rtt) as f64 / min_rtt as f64 * current_limit as f64) as usize
274        } else {
275            0
276        };
277
278        let new_limit = if queue_estimate < self.alpha {
279            // Under-utilized, increase
280            (current_limit + 1).min(self.max_limit)
281        } else if queue_estimate > self.beta {
282            // Congested, decrease
283            (current_limit.saturating_sub(1)).max(self.min_limit)
284        } else {
285            // In the sweet spot
286            current_limit
287        };
288
289        self.limit.store(new_limit, Ordering::Relaxed);
290    }
291}
292
293impl ConcurrencyAlgorithm for Vegas {
294    fn record_success(&self, latency: Duration) {
295        self.update_rtt(latency);
296        self.adjust_limit();
297    }
298
299    fn record_failure(&self) {
300        // On error, decrease limit immediately
301        let current = self.limit.load(Ordering::Relaxed);
302        let new_limit = (current / 2).max(self.min_limit);
303        self.limit.store(new_limit, Ordering::Relaxed);
304    }
305
306    fn record_dropped(&self) {
307        // Dropped requests don't affect the limit
308    }
309
310    fn limit(&self) -> usize {
311        self.limit.load(Ordering::Relaxed)
312    }
313
314    fn min_limit(&self) -> usize {
315        self.min_limit
316    }
317
318    fn max_limit(&self) -> usize {
319        self.max_limit
320    }
321}
322
/// Builder for the [`Vegas`] algorithm.
///
/// Defaults:
/// - initial limit 10, with bounds `1..=100`;
/// - `alpha` 3 and `beta` 6.
#[derive(Debug, Clone)]
pub struct VegasBuilder {
    /// Starting limit (clamped into the bounds when built).
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Queue-depth estimate below which the limit grows.
    alpha: usize,
    /// Queue-depth estimate above which the limit shrinks.
    beta: usize,
}

impl Default for VegasBuilder {
    fn default() -> Self {
        let (min_limit, max_limit) = (1, 100);
        let (alpha, beta) = (3, 6);
        Self {
            initial_limit: 10,
            min_limit,
            max_limit,
            alpha,
            beta,
        }
    }
}
344
345impl VegasBuilder {
346    /// Set the initial concurrency limit.
347    pub fn initial_limit(mut self, limit: usize) -> Self {
348        self.initial_limit = limit;
349        self
350    }
351
352    /// Set the minimum concurrency limit.
353    pub fn min_limit(mut self, limit: usize) -> Self {
354        self.min_limit = limit;
355        self
356    }
357
358    /// Set the maximum concurrency limit.
359    pub fn max_limit(mut self, limit: usize) -> Self {
360        self.max_limit = limit;
361        self
362    }
363
364    /// Set the alpha threshold (queue depth for increase).
365    ///
366    /// When the estimated queue depth is below alpha, the limit increases.
367    pub fn alpha(mut self, alpha: usize) -> Self {
368        self.alpha = alpha;
369        self
370    }
371
372    /// Set the beta threshold (queue depth for decrease).
373    ///
374    /// When the estimated queue depth is above beta, the limit decreases.
375    pub fn beta(mut self, beta: usize) -> Self {
376        self.beta = beta;
377        self
378    }
379
380    /// Build the Vegas algorithm.
381    pub fn build(self) -> Vegas {
382        Vegas::new(
383            self.initial_limit,
384            self.min_limit,
385            self.max_limit,
386            self.alpha,
387            self.beta,
388        )
389    }
390}
391
392/// Algorithm selection enum for the adaptive limiter.
393pub enum Algorithm {
394    /// AIMD algorithm
395    Aimd(Aimd),
396    /// Vegas algorithm
397    Vegas(Vegas),
398}
399
400impl ConcurrencyAlgorithm for Algorithm {
401    fn record_success(&self, latency: Duration) {
402        match self {
403            Algorithm::Aimd(a) => a.record_success(latency),
404            Algorithm::Vegas(v) => v.record_success(latency),
405        }
406    }
407
408    fn record_failure(&self) {
409        match self {
410            Algorithm::Aimd(a) => a.record_failure(),
411            Algorithm::Vegas(v) => v.record_failure(),
412        }
413    }
414
415    fn record_dropped(&self) {
416        match self {
417            Algorithm::Aimd(a) => a.record_dropped(),
418            Algorithm::Vegas(v) => v.record_dropped(),
419        }
420    }
421
422    fn limit(&self) -> usize {
423        match self {
424            Algorithm::Aimd(a) => a.limit(),
425            Algorithm::Vegas(v) => v.limit(),
426        }
427    }
428
429    fn min_limit(&self) -> usize {
430        match self {
431            Algorithm::Aimd(a) => a.min_limit(),
432            Algorithm::Vegas(v) => v.min_limit(),
433        }
434    }
435
436    fn max_limit(&self) -> usize {
437        match self {
438            Algorithm::Aimd(a) => a.max_limit(),
439            Algorithm::Vegas(v) => v.max_limit(),
440        }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    // --- AIMD -------------------------------------------------------------

    #[test]
    fn test_aimd_builder() {
        let aimd = Aimd::builder()
            .initial_limit(20)
            .min_limit(5)
            .max_limit(200)
            .increase_by(2)
            .decrease_factor(0.75)
            .latency_threshold(Duration::from_millis(50))
            .build();

        assert_eq!(
            (aimd.limit(), aimd.min_limit(), aimd.max_limit()),
            (20, 5, 200)
        );
    }

    #[test]
    fn test_aimd_success_increases() {
        let threshold = Duration::from_millis(100);
        let aimd = Aimd::builder()
            .initial_limit(10)
            .increase_by(1)
            .latency_threshold(threshold)
            .build();

        // Well under the threshold: a plain success, so the limit grows.
        aimd.record_success(threshold / 2);
        assert_eq!(aimd.limit(), 11);
    }

    #[test]
    fn test_aimd_high_latency_decreases() {
        let threshold = Duration::from_millis(100);
        let aimd = Aimd::builder()
            .initial_limit(10)
            .decrease_factor(0.5)
            .latency_threshold(threshold)
            .build();

        // Over the threshold: counted as congestion even though it succeeded.
        aimd.record_success(threshold + Duration::from_millis(50));
        assert_eq!(aimd.limit(), 5);
    }

    #[test]
    fn test_aimd_failure_decreases() {
        let aimd = Aimd::builder()
            .initial_limit(10)
            .decrease_factor(0.5)
            .build();

        aimd.record_failure();
        assert_eq!(aimd.limit(), 5);
    }

    // --- Vegas ------------------------------------------------------------

    #[test]
    fn test_vegas_builder() {
        let vegas = Vegas::builder()
            .initial_limit(20)
            .min_limit(5)
            .max_limit(200)
            .alpha(2)
            .beta(8)
            .build();

        assert_eq!(
            (vegas.limit(), vegas.min_limit(), vegas.max_limit()),
            (20, 5, 200)
        );
    }

    #[test]
    fn test_vegas_failure_decreases() {
        let vegas = Vegas::builder().initial_limit(20).min_limit(1).build();

        // An error halves the limit.
        vegas.record_failure();
        assert_eq!(vegas.limit(), 10);
    }

    #[test]
    fn test_vegas_min_rtt_tracking() {
        let vegas = Vegas::builder().initial_limit(10).build();

        for ms in [100, 50, 75] {
            vegas.record_success(Duration::from_millis(ms));
        }

        // The baseline is the smallest sample seen: 50 ms.
        let expected = Duration::from_millis(50).as_nanos() as u64;
        assert_eq!(vegas.min_rtt_nanos.load(Ordering::Relaxed), expected);
    }

    // --- Enum dispatch ----------------------------------------------------

    #[test]
    fn test_algorithm_enum() {
        let cases = [
            (Algorithm::Aimd(Aimd::builder().initial_limit(10).build()), 10),
            (Algorithm::Vegas(Vegas::builder().initial_limit(20).build()), 20),
        ];

        for (algorithm, expected) in cases {
            assert_eq!(algorithm.limit(), expected);
        }
    }
}