tower_resilience_adaptive/
algorithm.rs1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::time::Duration;
8use tower_resilience_core::aimd::{AimdConfig, AimdController};
9
/// Policy that decides how many requests may be in flight at the same time.
///
/// An implementation is told about the outcome of every request and, in turn,
/// reports the current concurrency limit together with its configured bounds.
/// Every method takes `&self`, so implementations that change their limit must
/// do so through interior mutability; the `Send + Sync` supertraits allow one
/// instance to be shared between threads.
pub trait ConcurrencyAlgorithm: Send + Sync {
    /// Reports a request that completed successfully, together with how long it took.
    fn record_success(&self, latency: Duration);

    /// Reports a request that failed.
    fn record_failure(&self);

    /// Reports a request that was dropped (what counts as "dropped" is decided by the caller).
    fn record_dropped(&self);

    /// The concurrency limit currently in effect.
    fn limit(&self) -> usize;

    /// The configured lower bound for [`limit`](Self::limit).
    fn min_limit(&self) -> usize;

    /// The configured upper bound for [`limit`](Self::limit).
    fn max_limit(&self) -> usize;
}
30
31pub struct Aimd {
39 controller: AimdController,
40 latency_threshold: Duration,
42}
43
44impl Aimd {
45 pub fn new(config: AimdConfig, latency_threshold: Duration) -> Self {
47 Self {
48 controller: AimdController::new(config),
49 latency_threshold,
50 }
51 }
52
53 pub fn builder() -> AimdBuilder {
55 AimdBuilder::default()
56 }
57}
58
59impl ConcurrencyAlgorithm for Aimd {
60 fn record_success(&self, latency: Duration) {
61 if latency > self.latency_threshold {
62 self.controller.record_failure();
64 } else {
65 self.controller.record_success();
66 }
67 }
68
69 fn record_failure(&self) {
70 self.controller.record_failure();
71 }
72
73 fn record_dropped(&self) {
74 }
76
77 fn limit(&self) -> usize {
78 self.controller.limit()
79 }
80
81 fn min_limit(&self) -> usize {
82 self.controller.min_limit()
83 }
84
85 fn max_limit(&self) -> usize {
86 self.controller.max_limit()
87 }
88}
89
/// Step-by-step configuration for an [`Aimd`] limiter.
///
/// Obtain one with `Aimd::builder()` (or `AimdBuilder::default()`), adjust the
/// settings you care about, then call `build`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `build` is called"]
pub struct AimdBuilder {
    /// Limit the algorithm starts from.
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Amount added to the limit on a success.
    increase_by: usize,
    /// Factor the limit is multiplied by on a failure.
    decrease_factor: f64,
    /// Successes slower than this are treated as failures.
    latency_threshold: Duration,
}
100
101impl Default for AimdBuilder {
102 fn default() -> Self {
103 Self {
104 initial_limit: 10,
105 min_limit: 1,
106 max_limit: 100,
107 increase_by: 1,
108 decrease_factor: 0.5,
109 latency_threshold: Duration::from_millis(100),
110 }
111 }
112}
113
114impl AimdBuilder {
115 pub fn initial_limit(mut self, limit: usize) -> Self {
117 self.initial_limit = limit;
118 self
119 }
120
121 pub fn min_limit(mut self, limit: usize) -> Self {
123 self.min_limit = limit;
124 self
125 }
126
127 pub fn max_limit(mut self, limit: usize) -> Self {
129 self.max_limit = limit;
130 self
131 }
132
133 pub fn increase_by(mut self, amount: usize) -> Self {
135 self.increase_by = amount;
136 self
137 }
138
139 pub fn decrease_factor(mut self, factor: f64) -> Self {
141 self.decrease_factor = factor;
142 self
143 }
144
145 pub fn latency_threshold(mut self, threshold: Duration) -> Self {
149 self.latency_threshold = threshold;
150 self
151 }
152
153 pub fn build(self) -> Aimd {
155 let config = AimdConfig::new()
156 .with_initial_limit(self.initial_limit)
157 .with_min_limit(self.min_limit)
158 .with_max_limit(self.max_limit)
159 .with_increase_by(self.increase_by)
160 .with_decrease_factor(self.decrease_factor);
161
162 Aimd::new(config, self.latency_threshold)
163 }
164}
165
/// Delay-based concurrency limiter modelled on TCP Vegas.
///
/// Tracks the smallest round-trip time seen (the no-load baseline) and an
/// exponentially smoothed RTT. From those two it estimates how many requests
/// are queued and moves the limit up or down by one; failures halve it. All
/// mutable state lives in atomics, so the limiter is updated through `&self`.
#[derive(Debug)]
pub struct Vegas {
    /// Concurrency limit currently in effect.
    limit: AtomicUsize,
    /// Lower bound for `limit`.
    min_limit: usize,
    /// Upper bound for `limit`.
    max_limit: usize,
    /// Smallest RTT observed, in nanoseconds; `u64::MAX` until the first sample.
    min_rtt_nanos: AtomicU64,
    /// Queue estimate below which the limit is increased.
    alpha: usize,
    /// Queue estimate above which the limit is decreased.
    beta: usize,
    /// Weight of the newest sample in the smoothed RTT (meaningful in `[0, 1]`).
    smoothing: f64,
    /// Exponentially smoothed RTT, in nanoseconds; 0 until the first sample.
    smoothed_rtt_nanos: AtomicU64,
    /// Number of RTT samples recorded so far.
    sample_count: AtomicUsize,
    /// Samples required before the limit starts adapting.
    min_samples: usize,
}
195
196impl Vegas {
197 pub fn new(
199 initial_limit: usize,
200 min_limit: usize,
201 max_limit: usize,
202 alpha: usize,
203 beta: usize,
204 ) -> Self {
205 Self {
206 limit: AtomicUsize::new(initial_limit.clamp(min_limit, max_limit)),
207 min_limit,
208 max_limit,
209 min_rtt_nanos: AtomicU64::new(u64::MAX),
210 alpha,
211 beta,
212 smoothing: 0.5,
213 smoothed_rtt_nanos: AtomicU64::new(0),
214 sample_count: AtomicUsize::new(0),
215 min_samples: 10,
216 }
217 }
218
219 pub fn builder() -> VegasBuilder {
221 VegasBuilder::default()
222 }
223
224 fn update_rtt(&self, rtt: Duration) {
225 let rtt_nanos = rtt.as_nanos() as u64;
226
227 let mut current_min = self.min_rtt_nanos.load(Ordering::Relaxed);
229 while rtt_nanos < current_min {
230 match self.min_rtt_nanos.compare_exchange_weak(
231 current_min,
232 rtt_nanos,
233 Ordering::Relaxed,
234 Ordering::Relaxed,
235 ) {
236 Ok(_) => break,
237 Err(c) => current_min = c,
238 }
239 }
240
241 let current_smoothed = self.smoothed_rtt_nanos.load(Ordering::Relaxed);
243 let new_smoothed = if current_smoothed == 0 {
244 rtt_nanos
245 } else {
246 (self.smoothing * rtt_nanos as f64 + (1.0 - self.smoothing) * current_smoothed as f64)
247 as u64
248 };
249 self.smoothed_rtt_nanos
250 .store(new_smoothed, Ordering::Relaxed);
251
252 self.sample_count.fetch_add(1, Ordering::Relaxed);
253 }
254
255 fn adjust_limit(&self) {
256 if self.sample_count.load(Ordering::Relaxed) < self.min_samples {
258 return;
259 }
260
261 let min_rtt = self.min_rtt_nanos.load(Ordering::Relaxed);
262 let smoothed_rtt = self.smoothed_rtt_nanos.load(Ordering::Relaxed);
263
264 if min_rtt == u64::MAX || min_rtt == 0 || smoothed_rtt == 0 {
265 return;
266 }
267
268 let current_limit = self.limit.load(Ordering::Relaxed);
269
270 let queue_estimate = if smoothed_rtt > min_rtt {
273 ((smoothed_rtt - min_rtt) as f64 / min_rtt as f64 * current_limit as f64) as usize
274 } else {
275 0
276 };
277
278 let new_limit = if queue_estimate < self.alpha {
279 (current_limit + 1).min(self.max_limit)
281 } else if queue_estimate > self.beta {
282 (current_limit.saturating_sub(1)).max(self.min_limit)
284 } else {
285 current_limit
287 };
288
289 self.limit.store(new_limit, Ordering::Relaxed);
290 }
291}
292
293impl ConcurrencyAlgorithm for Vegas {
294 fn record_success(&self, latency: Duration) {
295 self.update_rtt(latency);
296 self.adjust_limit();
297 }
298
299 fn record_failure(&self) {
300 let current = self.limit.load(Ordering::Relaxed);
302 let new_limit = (current / 2).max(self.min_limit);
303 self.limit.store(new_limit, Ordering::Relaxed);
304 }
305
306 fn record_dropped(&self) {
307 }
309
310 fn limit(&self) -> usize {
311 self.limit.load(Ordering::Relaxed)
312 }
313
314 fn min_limit(&self) -> usize {
315 self.min_limit
316 }
317
318 fn max_limit(&self) -> usize {
319 self.max_limit
320 }
321}
322
/// Step-by-step configuration for a [`Vegas`] limiter.
///
/// Obtain one with `Vegas::builder()` (or `VegasBuilder::default()`), adjust the
/// settings you care about, then call `build`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `build` is called"]
pub struct VegasBuilder {
    /// Limit the algorithm starts from (clamped into `[min_limit, max_limit]`).
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Queue estimate below which the limit grows.
    alpha: usize,
    /// Queue estimate above which the limit shrinks.
    beta: usize,
}
332
333impl Default for VegasBuilder {
334 fn default() -> Self {
335 Self {
336 initial_limit: 10,
337 min_limit: 1,
338 max_limit: 100,
339 alpha: 3,
340 beta: 6,
341 }
342 }
343}
344
345impl VegasBuilder {
346 pub fn initial_limit(mut self, limit: usize) -> Self {
348 self.initial_limit = limit;
349 self
350 }
351
352 pub fn min_limit(mut self, limit: usize) -> Self {
354 self.min_limit = limit;
355 self
356 }
357
358 pub fn max_limit(mut self, limit: usize) -> Self {
360 self.max_limit = limit;
361 self
362 }
363
364 pub fn alpha(mut self, alpha: usize) -> Self {
368 self.alpha = alpha;
369 self
370 }
371
372 pub fn beta(mut self, beta: usize) -> Self {
376 self.beta = beta;
377 self
378 }
379
380 pub fn build(self) -> Vegas {
382 Vegas::new(
383 self.initial_limit,
384 self.min_limit,
385 self.max_limit,
386 self.alpha,
387 self.beta,
388 )
389 }
390}
391
392pub enum Algorithm {
394 Aimd(Aimd),
396 Vegas(Vegas),
398}
399
400impl ConcurrencyAlgorithm for Algorithm {
401 fn record_success(&self, latency: Duration) {
402 match self {
403 Algorithm::Aimd(a) => a.record_success(latency),
404 Algorithm::Vegas(v) => v.record_success(latency),
405 }
406 }
407
408 fn record_failure(&self) {
409 match self {
410 Algorithm::Aimd(a) => a.record_failure(),
411 Algorithm::Vegas(v) => v.record_failure(),
412 }
413 }
414
415 fn record_dropped(&self) {
416 match self {
417 Algorithm::Aimd(a) => a.record_dropped(),
418 Algorithm::Vegas(v) => v.record_dropped(),
419 }
420 }
421
422 fn limit(&self) -> usize {
423 match self {
424 Algorithm::Aimd(a) => a.limit(),
425 Algorithm::Vegas(v) => v.limit(),
426 }
427 }
428
429 fn min_limit(&self) -> usize {
430 match self {
431 Algorithm::Aimd(a) => a.min_limit(),
432 Algorithm::Vegas(v) => v.min_limit(),
433 }
434 }
435
436 fn max_limit(&self) -> usize {
437 match self {
438 Algorithm::Aimd(a) => a.max_limit(),
439 Algorithm::Vegas(v) => v.max_limit(),
440 }
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder settings are reflected by the constructed AIMD limiter.
    #[test]
    fn test_aimd_builder() {
        let algo = Aimd::builder()
            .initial_limit(20)
            .min_limit(5)
            .max_limit(200)
            .increase_by(2)
            .decrease_factor(0.75)
            .latency_threshold(Duration::from_millis(50))
            .build();

        assert_eq!(
            (algo.limit(), algo.min_limit(), algo.max_limit()),
            (20, 5, 200)
        );
    }

    /// A fast success grows the limit by `increase_by`.
    #[test]
    fn test_aimd_success_increases() {
        let algo = Aimd::builder()
            .initial_limit(10)
            .increase_by(1)
            .latency_threshold(Duration::from_millis(100))
            .build();

        // 50 ms is under the 100 ms threshold, so this is a genuine success.
        algo.record_success(Duration::from_millis(50));
        assert_eq!(algo.limit(), 10 + 1);
    }

    /// A success slower than the threshold is treated as a failure.
    #[test]
    fn test_aimd_high_latency_decreases() {
        let algo = Aimd::builder()
            .initial_limit(10)
            .decrease_factor(0.5)
            .latency_threshold(Duration::from_millis(100))
            .build();

        algo.record_success(Duration::from_millis(150));
        assert_eq!(algo.limit(), 10 / 2);
    }

    /// An explicit failure shrinks the limit by `decrease_factor`.
    #[test]
    fn test_aimd_failure_decreases() {
        let algo = Aimd::builder()
            .initial_limit(10)
            .decrease_factor(0.5)
            .build();

        algo.record_failure();
        assert_eq!(algo.limit(), 5);
    }

    /// Builder settings are reflected by the constructed Vegas limiter.
    #[test]
    fn test_vegas_builder() {
        let algo = Vegas::builder()
            .initial_limit(20)
            .min_limit(5)
            .max_limit(200)
            .alpha(2)
            .beta(8)
            .build();

        assert_eq!(
            (algo.limit(), algo.min_limit(), algo.max_limit()),
            (20, 5, 200)
        );
    }

    /// A failure halves the Vegas limit.
    #[test]
    fn test_vegas_failure_decreases() {
        let algo = Vegas::builder().initial_limit(20).min_limit(1).build();

        algo.record_failure();
        assert_eq!(algo.limit(), 20 / 2);
    }

    /// The smallest latency seen is remembered as the baseline RTT.
    #[test]
    fn test_vegas_min_rtt_tracking() {
        let algo = Vegas::builder().initial_limit(10).build();

        for &millis in &[100, 50, 75] {
            algo.record_success(Duration::from_millis(millis));
        }

        let expected = Duration::from_millis(50).as_nanos() as u64;
        assert_eq!(algo.min_rtt_nanos.load(Ordering::Relaxed), expected);
    }

    /// The enum reports the limit of whichever algorithm it wraps.
    #[test]
    fn test_algorithm_enum() {
        let wrapped_aimd = Algorithm::Aimd(Aimd::builder().initial_limit(10).build());
        let wrapped_vegas = Algorithm::Vegas(Vegas::builder().initial_limit(20).build());

        assert_eq!(wrapped_aimd.limit(), 10);
        assert_eq!(wrapped_vegas.limit(), 20);
    }
}