Skip to main content

tower_resilience_hedge/
lib.rs

1//! Hedging middleware for Tower services.
2//!
3//! Hedging reduces tail latency by executing parallel redundant requests.
4//! Instead of waiting for a slow request to complete, hedging fires additional
5//! requests after a configurable delay and returns whichever completes first.
6//!
7//! # Overview
8//!
9//! The hedging pattern is useful when:
10//! - Tail latency (P99/P999) is critical
11//! - Operations are idempotent and safe to retry
12//! - You can trade increased resource usage for lower latency
13//!
14//! # Presets
15//!
16//! ```rust
17//! use tower_resilience_hedge::HedgeLayer;
18//!
19//! let conservative = HedgeLayer::conservative(); // 500ms delay, 2 attempts
20//! let standard = HedgeLayer::standard();         // 100ms delay, 3 attempts
21//! let aggressive = HedgeLayer::aggressive();     // 50ms delay, 5 attempts
22//! ```
23//!
24//! # Modes
25//!
26//! ## Latency Mode (delay > 0)
27//!
28//! Wait a specified duration before firing hedge requests. This is the default
29//! and most common mode - it only sends extra requests if the primary is slow.
30//!
31//! ```rust,no_run
32//! use tower_resilience_hedge::HedgeLayer;
33//! use std::time::Duration;
34//!
35//! // No type parameters needed! Fire a hedge request if primary takes > 100ms
36//! let layer = HedgeLayer::builder()
37//!     .delay(Duration::from_millis(100))
38//!     .max_hedged_attempts(2)
39//!     .build();
40//! ```
41//!
42//! ## Parallel Mode (delay = 0)
43//!
44//! Fire all requests simultaneously and return the fastest response.
45//! Use when latency is critical and you can afford the resource cost.
46//!
47//! ```rust,no_run
48//! use tower_resilience_hedge::HedgeLayer;
49//!
50//! // No type parameters needed! Fire 3 requests immediately, return fastest
51//! let layer = HedgeLayer::builder()
52//!     .no_delay()
53//!     .max_hedged_attempts(3)
54//!     .build();
55//! ```
56//!
57//! # Example
58//!
59//! ```rust,no_run
60//! use tower::{Service, ServiceExt, Layer};
61//! use tower_resilience_hedge::HedgeLayer;
62//! use std::time::Duration;
63//!
64//! // Define a simple cloneable error type
65//! #[derive(Clone, Debug)]
66//! struct MyError;
67//! impl std::fmt::Display for MyError {
68//!     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69//!         write!(f, "MyError")
70//!     }
71//! }
72//! impl std::error::Error for MyError {}
73//!
74//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
75//! // Create a service that sometimes responds slowly
76//! let service = tower::service_fn(|req: String| async move {
77//!     // Simulate variable latency
78//!     Ok::<_, MyError>(format!("response: {}", req))
79//! });
80//!
81//! // Wrap with hedging - fire hedge after 50ms (no type parameters needed!)
82//! let hedge = HedgeLayer::builder()
83//!     .delay(Duration::from_millis(50))
84//!     .max_hedged_attempts(2)
85//!     .build();
86//!
87//! let mut service = hedge.layer(service);
88//!
89//! let response = service.ready().await?.call("hello".to_string()).await?;
90//! println!("Got response: {}", response);
91//! # Ok(())
92//! # }
93//! ```
94//!
95//! # Cancellation
96//!
97//! When one request succeeds, all other in-flight requests are cancelled
98//! by dropping their futures. This relies on the inner service supporting
99//! cooperative cancellation.
100//!
101//! # Type Requirements
102//!
103//! Hedging has specific trait bounds that differ from other resilience patterns:
104//!
105//! - **`Req: Clone`** - Required because the request is cloned to send parallel
106//!   requests. Each hedge attempt needs its own copy of the request.
107//!
108//! - **`E: Clone`** - Required for error handling. When multiple attempts fail,
109//!   errors need to be collected and stored to return the final error.
110//!
111//! If your request or error types don't implement `Clone`, consider:
112//! - Wrapping them in `Arc` (e.g., `Arc<MyRequest>`)
113//! - Using a different resilience pattern like Retry which doesn't require
114//!   cloning requests
115
116mod config;
117mod error;
118mod events;
119mod layer;
120
121pub use config::{HedgeConfig, HedgeConfigBuilder, HedgeDelay};
122pub use error::HedgeError;
123pub use events::HedgeEvent;
124pub use layer::HedgeLayer;
125
126use futures::future::BoxFuture;
127use std::sync::Arc;
128use std::task::{Context, Poll};
129use std::time::{Duration, Instant};
130use tower::Service;
131
132/// Hedging service that wraps an inner service.
133///
134/// This service executes parallel redundant requests to reduce tail latency.
135/// It fires additional "hedge" requests after a configurable delay and returns
136/// whichever request completes first successfully.
137///
138/// The type parameter is just the inner service type - request, response, and
139/// error types are derived from the service's associated types.
140pub struct Hedge<S> {
141    inner: S,
142    config: Arc<HedgeConfig>,
143}
144
145impl<S> Hedge<S> {
146    /// Create a new Hedge service with the given configuration.
147    pub fn new(inner: S, config: HedgeConfig) -> Self {
148        Self {
149            inner,
150            config: Arc::new(config),
151        }
152    }
153}
154
155impl<S: Clone> Clone for Hedge<S> {
156    fn clone(&self) -> Self {
157        Self {
158            inner: self.inner.clone(),
159            config: Arc::clone(&self.config),
160        }
161    }
162}
163
164impl<S, Req> Service<Req> for Hedge<S>
165where
166    S: Service<Req> + Clone + Send + 'static,
167    S::Response: Send + Sync + 'static,
168    S::Error: Clone + Send + Sync + 'static,
169    S::Future: Send,
170    Req: Clone + Send + Sync + 'static,
171{
172    type Response = S::Response;
173    type Error = HedgeError<S::Error>;
174    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
175
176    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.inner.poll_ready(cx).map_err(HedgeError::Inner)
178    }
179
180    fn call(&mut self, req: Req) -> Self::Future {
181        let config = Arc::clone(&self.config);
182        let inner = self.inner.clone();
183        // Replace the clone we just made with the ready service
184        let inner = std::mem::replace(&mut self.inner, inner);
185
186        Box::pin(async move { execute_with_hedging(inner, req, config).await })
187    }
188}
189
190/// Execute the request with hedging strategy
191async fn execute_with_hedging<S, Req>(
192    service: S,
193    req: Req,
194    config: Arc<HedgeConfig>,
195) -> Result<S::Response, HedgeError<S::Error>>
196where
197    S: Service<Req> + Clone + Send + 'static,
198    S::Response: Send + 'static,
199    S::Error: Clone + Send + 'static,
200    S::Future: Send,
201    Req: Clone + Send + 'static,
202{
203    use tokio::sync::mpsc;
204
205    let max_attempts = config.max_hedged_attempts;
206    let start = Instant::now();
207
208    // Emit primary started event
209    config.listeners.emit(&HedgeEvent::PrimaryStarted {
210        name: config.name.clone(),
211        timestamp: Instant::now(),
212    });
213
214    // Channel to collect results from all attempts
215    let (tx, mut rx) = mpsc::channel::<(usize, Result<S::Response, S::Error>)>(max_attempts);
216
217    // Spawn primary request
218    let mut service_clone = service.clone();
219    let req_clone = req.clone();
220    let tx_clone = tx.clone();
221    tokio::spawn(async move {
222        let result = service_clone.call(req_clone).await;
223        let _ = tx_clone.send((0, result)).await;
224    });
225
226    // Track spawned hedge tasks
227    let mut hedges_spawned: usize = 0;
228    let mut primary_error: Option<S::Error> = None;
229
230    // Get delay for first hedge
231    let first_delay = config.delay.get_delay(1);
232
233    // If we have more attempts and there's a delay, set up hedge timing
234    if max_attempts > 1 {
235        match first_delay {
236            Some(delay) if delay > Duration::ZERO => {
237                // Latency mode: wait for delay or result
238                let mut delay_fut = std::pin::pin!(tokio::time::sleep(delay));
239
240                loop {
241                    tokio::select! {
242                        biased;
243
244                        // Check for results
245                        Some((attempt, result)) = rx.recv() => {
246                            match &result {
247                                Ok(_) => {
248                                    let duration = start.elapsed();
249                                    if attempt == 0 {
250                                        config.listeners.emit(&HedgeEvent::PrimarySucceeded {
251                                            name: config.name.clone(),
252                                            duration,
253                                            hedges_cancelled: hedges_spawned,
254                                            timestamp: Instant::now(),
255                                        });
256                                    } else {
257                                        config.listeners.emit(&HedgeEvent::HedgeSucceeded {
258                                            name: config.name.clone(),
259                                            attempt,
260                                            duration,
261                                            primary_cancelled: true,
262                                            timestamp: Instant::now(),
263                                        });
264                                    }
265                                    return result.map_err(HedgeError::Inner);
266                                }
267                                Err(e) => {
268                                    // Store error, continue waiting for other attempts
269                                    if attempt == 0 {
270                                        primary_error = Some(e.clone());
271                                    }
272                                    // Check if all attempts exhausted
273                                    if hedges_spawned + 1 >= max_attempts {
274                                        // All spawned, check if this was the last result
275                                        config.listeners.emit(&HedgeEvent::AllFailed {
276                                            name: config.name.clone(),
277                                            attempts: hedges_spawned + 1,
278                                            timestamp: Instant::now(),
279                                        });
280                                        return Err(HedgeError::AllAttemptsFailed(
281                                            primary_error.unwrap_or_else(|| e.clone())
282                                        ));
283                                    }
284                                }
285                            }
286                        }
287
288                        // Delay elapsed, spawn hedge
289                        _ = &mut delay_fut, if hedges_spawned + 1 < max_attempts => {
290                            hedges_spawned += 1;
291                            let attempt_num = hedges_spawned;
292
293                            config.listeners.emit(&HedgeEvent::HedgeStarted {
294                                name: config.name.clone(),
295                                attempt: attempt_num,
296                                delay,
297                                timestamp: Instant::now(),
298                            });
299
300                            let mut svc = service.clone();
301                            let r = req.clone();
302                            let tx_c = tx.clone();
303                            tokio::spawn(async move {
304                                let result = svc.call(r).await;
305                                let _ = tx_c.send((attempt_num, result)).await;
306                            });
307
308                            // Set up next delay if more hedges available
309                            if hedges_spawned + 1 < max_attempts {
310                                if let Some(next_delay) = config.delay.get_delay(hedges_spawned + 1) {
311                                    delay_fut.set(tokio::time::sleep(next_delay));
312                                }
313                            }
314                        }
315
316                        else => {
317                            // No more hedges to spawn, just wait for results
318                            if let Some((attempt, result)) = rx.recv().await {
319                                match &result {
320                                    Ok(_) => {
321                                        let duration = start.elapsed();
322                                        if attempt == 0 {
323                                            config.listeners.emit(&HedgeEvent::PrimarySucceeded {
324                                                name: config.name.clone(),
325                                                duration,
326                                                hedges_cancelled: hedges_spawned,
327                                                timestamp: Instant::now(),
328                                            });
329                                        } else {
330                                            config.listeners.emit(&HedgeEvent::HedgeSucceeded {
331                                                name: config.name.clone(),
332                                                attempt,
333                                                duration,
334                                                primary_cancelled: attempt != 0,
335                                                timestamp: Instant::now(),
336                                            });
337                                        }
338                                        return result.map_err(HedgeError::Inner);
339                                    }
340                                    Err(e) => {
341                                        if attempt == 0 && primary_error.is_none() {
342                                            primary_error = Some(e.clone());
343                                        }
344                                    }
345                                }
346                            } else {
347                                // Channel closed, all senders dropped
348                                break;
349                            }
350                        }
351                    }
352                }
353            }
354            _ => {
355                // Parallel mode: spawn all hedges immediately
356                for i in 1..max_attempts {
357                    hedges_spawned += 1;
358
359                    config.listeners.emit(&HedgeEvent::HedgeStarted {
360                        name: config.name.clone(),
361                        attempt: i,
362                        delay: Duration::ZERO,
363                        timestamp: Instant::now(),
364                    });
365
366                    let mut svc = service.clone();
367                    let r = req.clone();
368                    let tx_c = tx.clone();
369                    tokio::spawn(async move {
370                        let result = svc.call(r).await;
371                        let _ = tx_c.send((i, result)).await;
372                    });
373                }
374            }
375        }
376    }
377
378    // Drop our sender so channel closes when all tasks complete
379    drop(tx);
380
381    // Wait for first success or all failures
382    let mut attempts_received: usize = 0;
383    let total_attempts = hedges_spawned + 1;
384
385    while let Some((attempt, result)) = rx.recv().await {
386        attempts_received += 1;
387
388        match result {
389            Ok(res) => {
390                let duration = start.elapsed();
391                if attempt == 0 {
392                    config.listeners.emit(&HedgeEvent::PrimarySucceeded {
393                        name: config.name.clone(),
394                        duration,
395                        hedges_cancelled: hedges_spawned.saturating_sub(attempts_received - 1),
396                        timestamp: Instant::now(),
397                    });
398                } else {
399                    config.listeners.emit(&HedgeEvent::HedgeSucceeded {
400                        name: config.name.clone(),
401                        attempt,
402                        duration,
403                        primary_cancelled: true,
404                        timestamp: Instant::now(),
405                    });
406                }
407                return Ok(res);
408            }
409            Err(e) => {
410                if primary_error.is_none() {
411                    primary_error = Some(e);
412                }
413            }
414        }
415    }
416
417    // All attempts failed
418    config.listeners.emit(&HedgeEvent::AllFailed {
419        name: config.name.clone(),
420        attempts: total_attempts,
421        timestamp: Instant::now(),
422    });
423
424    Err(HedgeError::AllAttemptsFailed(
425        primary_error.expect("at least one error should exist"),
426    ))
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::{Layer, ServiceExt};

    /// Minimal cloneable error type for the services under test.
    #[derive(Clone, Debug)]
    struct TestError;

    /// Waits for `service` to become ready and sends a single `"test"` request.
    async fn send_test_request<S>(service: &mut S) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
        S::Error: std::fmt::Debug,
    {
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
    }

    /// A fast primary must answer before the hedge delay, so no hedge is sent.
    #[tokio::test]
    async fn test_primary_succeeds_no_hedge() {
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);

        let inner = tower::service_fn(move |_req: String| {
            let seen = Arc::clone(&seen);
            async move {
                seen.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>("success".to_string())
            }
        });

        let mut hedged = HedgeLayer::builder()
            .delay(Duration::from_millis(100))
            .max_hedged_attempts(2)
            .build()
            .layer(inner);

        let outcome = send_test_request(&mut hedged).await;
        assert!(outcome.is_ok());

        // Leave a short window in which a stray hedge would show up.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // The primary was fast, so exactly one call is expected.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With no delay, every configured attempt is started.
    #[tokio::test]
    async fn test_parallel_mode_all_called() {
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);

        let inner = tower::service_fn(move |_req: String| {
            let seen = Arc::clone(&seen);
            async move {
                seen.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok::<_, TestError>("success".to_string())
            }
        });

        let mut hedged = HedgeLayer::builder()
            .no_delay()
            .max_hedged_attempts(3)
            .build()
            .layer(inner);

        let outcome = send_test_request(&mut hedged).await;
        assert!(outcome.is_ok());

        // Allow every spawned attempt to bump the counter.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Parallel mode starts all three attempts.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A slow primary is overtaken by a hedge fired after the delay.
    #[tokio::test]
    async fn test_hedge_fires_after_delay() {
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);

        let inner = tower::service_fn(move |_req: String| {
            let seen = Arc::clone(&seen);
            async move {
                // Only the very first call is slow; later calls return at once.
                let earlier_calls = seen.fetch_add(1, Ordering::SeqCst);
                if earlier_calls == 0 {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                }
                Ok::<_, TestError>("success".to_string())
            }
        });

        let mut hedged = HedgeLayer::builder()
            .delay(Duration::from_millis(50))
            .max_hedged_attempts(2)
            .build()
            .layer(inner);

        let started = Instant::now();
        let outcome = send_test_request(&mut hedged).await;
        let took = started.elapsed();

        assert!(outcome.is_ok());
        // The hedge answers well before the 200ms primary would.
        assert!(took < Duration::from_millis(150));

        // Primary and hedge were both invoked.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// When every attempt fails, the aggregate failure variant is returned.
    #[tokio::test]
    async fn test_all_fail_returns_error() {
        let inner = tower::service_fn(|_req: String| async move { Err::<String, _>(TestError) });

        let mut hedged = HedgeLayer::builder()
            .no_delay()
            .max_hedged_attempts(2)
            .build()
            .layer(inner);

        let outcome = send_test_request(&mut hedged).await;
        assert!(matches!(outcome, Err(HedgeError::AllAttemptsFailed(_))));
    }

    /// Smoke test: the conservative preset can be constructed.
    #[test]
    fn test_preset_conservative() {
        let _layer = HedgeLayer::conservative();
    }

    /// Smoke test: the standard preset can be constructed.
    #[test]
    fn test_preset_standard() {
        let _layer = HedgeLayer::standard();
    }

    /// Smoke test: the aggressive preset can be constructed.
    #[test]
    fn test_preset_aggressive() {
        let _layer = HedgeLayer::aggressive();
    }
}