tower_circuitbreaker/
lib.rs

1//! # DEPRECATED: Use `tower-resilience` instead
2//!
3//! This crate has been superseded by [`tower-resilience`](https://crates.io/crates/tower-resilience),
4//! which provides a comprehensive suite of resilience patterns including circuit breakers,
5//! bulkheads, retries, time limiters, caching, and rate limiting.
6//!
7//! ## Migration
8//!
9//! **Old (tower-circuitbreaker):**
10//! ```toml
11//! [dependencies]
12//! tower-circuitbreaker = "0.1"
13//! ```
14//!
15//! **New (tower-resilience):**
16//! ```toml
17//! [dependencies]
18//! tower-resilience = "0.2"
19//! # Or use the individual crate:
20//! tower-resilience-circuitbreaker = "0.3"
21//! ```
22//!
23//! **API Changes:**
24//! ```rust,ignore
25//! // Old API
26//! use tower_circuitbreaker::circuit_breaker_builder;
27//! let cb = circuit_breaker_builder::<String, ()>()
28//!     .failure_rate_threshold(0.5)
29//!     .build();
30//!
31//! // New API
32//! use tower_resilience::circuitbreaker::CircuitBreakerLayer;
33//! let cb = CircuitBreakerLayer::<String, ()>::builder()
34//!     .failure_rate_threshold(0.5)
35//!     .build();
36//! ```
37//!
38//! For more information, see:
39//! - Repository: <https://github.com/joshrotenberg/tower-resilience>
40//! - Documentation: <https://docs.rs/tower-resilience>
41//!
42//! ---
43//!
44//! ## Original Documentation
45//!
46//! A Tower middleware implementing circuit breaker behavior to improve the resilience of asynchronous services.
47//!
48//! ## Features
49//! - Circuit breaker states: Closed, Open, Half-Open
50//! - Configurable failure rate threshold and sliding window size
51//! - Customizable `failure_classifier` to define what counts as a failure
52//! - Metrics support via the `metrics` feature flag
53//! - Tracing support via the `tracing` feature flag
54//!
55//! ## Example
56//! ```rust
57//! use tower_circuitbreaker::circuit_breaker_builder;
58//! use tower::ServiceBuilder;
59//! use tower::service_fn;
60//! use tower::Service;
61//! use std::time::Duration;
62//!
63//! #[tokio::main]
64//! async fn main() {
65//!     // Build a circuit breaker layer with custom settings
66//!     let circuit_breaker_layer = circuit_breaker_builder::<_, ()>()
67//!         .failure_rate_threshold(0.3)
68//!         .sliding_window_size(50)
69//!         .wait_duration_in_open(Duration::from_secs(10))
70//!         .build();
71//!
72//!     // Create a minimal service that echoes the request
73//!     let my_service = service_fn(|req| async move { Ok::<_, ()>(req) });
74//!
75//!     // Wrap the service with the circuit breaker
76//!     let mut service = ServiceBuilder::new()
77//!         .layer(circuit_breaker_layer)
78//!         .service(my_service);
79//!
80//!     // Use the service
81//!     let response = Service::call(&mut service, "hello").await.unwrap();
82//!     assert_eq!(response, "hello");
83//!
84//!     // Error handling example
85//!     match Service::call(&mut service, "hello").await {
86//!         Ok(resp) => println!("got {}", resp),
87//!         Err(e) if e.is_circuit_open() => println!("circuit open"),
88//!         Err(e) => eprintln!("service error: {:?}", e.into_inner()),
89//!     }
90//! }
91//! ```
92//!
93//! ## Feature Flags
94//! - `metrics`: enables metrics collection using the `metrics` crate.
95//! - `tracing`: enables logging and tracing using the `tracing` crate.
96
97use crate::circuit::Circuit;
98use crate::config::CircuitBreakerConfig;
99use crate::layer::CircuitBreakerLayerBuilder;
100use futures::future::BoxFuture;
101#[cfg(feature = "metrics")]
102use metrics::{counter, describe_counter, describe_gauge};
103use std::sync::Arc;
104#[cfg(feature = "metrics")]
105use std::sync::Once;
106use std::task::{Context, Poll};
107use tokio::sync::Mutex;
108use tower::Service;
109#[cfg(feature = "tracing")]
110use tracing::debug;
111
112pub use circuit::CircuitState;
113pub use error::CircuitBreakerError;
114
115mod circuit;
116mod config;
117mod error;
118mod layer;
119
/// Unsized predicate type: a thread-safe (`Send + Sync`) function that inspects a borrowed
/// `Result<Res, Err>` and returns a `bool`.
///
/// NOTE(review): presumably this is the type behind `CircuitBreakerConfig::failure_classifier`,
/// where a `true` result marks the call as a failure — confirm against the `config` module.
pub(crate) type FailureClassifier<Res, Err> = dyn Fn(&Result<Res, Err>) -> bool + Send + Sync;
/// Reference-counted handle to a [`FailureClassifier`], so one predicate can be shared cheaply.
pub(crate) type SharedFailureClassifier<Res, Err> = Arc<FailureClassifier<Res, Err>>;
122
/// Fallback breaker name used in tracing output when the configuration has no `name` set.
#[cfg(feature = "tracing")]
pub(crate) static DEFAULT_CIRCUIT_BREAKER_NAME: &str = "<unnamed>";

/// Ensures the metric descriptions are registered at most once per process, no matter how many
/// times `circuit_breaker_builder` is called.
#[cfg(feature = "metrics")]
static METRICS_INIT: Once = Once::new();
128
129/// Returns a new builder for a `CircuitBreakerLayer`.
130pub fn circuit_breaker_builder<Res, Err>() -> CircuitBreakerLayerBuilder<Res, Err> {
131    #[cfg(feature = "metrics")]
132    {
133        METRICS_INIT.call_once(|| {
134            describe_counter!(
135                "circuitbreaker_calls_total",
136                "Total number of calls through the circuit breaker"
137            );
138            describe_counter!(
139                "circuitbreaker_transitions_total",
140                "Total number of circuit breaker state transitions"
141            );
142            describe_gauge!(
143                "circuitbreaker_state",
144                "Current state of the circuit breaker"
145            );
146        });
147    }
148    CircuitBreakerLayerBuilder::default()
149}
150
/// A Tower Service that applies circuit breaker logic to an inner service.
///
/// Manages the circuit state and controls calls to the inner service accordingly.
///
/// `S` is the wrapped service; `Res` and `Err` are its response and error types, which also
/// parameterize the configuration.
pub struct CircuitBreaker<S, Res, Err> {
    /// The wrapped service that calls are forwarded to when the circuit permits them.
    inner: S,
    /// Circuit state machine behind an async mutex; the `Arc` lets each in-flight call future
    /// hold its own handle to the same circuit.
    circuit: Arc<Mutex<Circuit>>,
    /// Shared breaker settings; also handed to every call future via `Arc::clone`.
    config: Arc<CircuitBreakerConfig<Res, Err>>,
}
159
160impl<S, Res, Err> CircuitBreaker<S, Res, Err> {
161    /// Creates a new `CircuitBreaker` wrapping the given service and configuration.
162    pub(crate) fn new(inner: S, config: Arc<CircuitBreakerConfig<Res, Err>>) -> Self {
163        Self {
164            inner,
165            circuit: Arc::new(Mutex::new(Circuit::new())),
166            config,
167        }
168    }
169
170    /// Forces the circuit into the open state.
171    pub async fn force_open(&self) {
172        let mut circuit = self.circuit.lock().await;
173        circuit.force_open();
174    }
175
176    /// Forces the circuit into the closed state.
177    pub async fn force_closed(&self) {
178        let mut circuit = self.circuit.lock().await;
179        circuit.force_closed();
180    }
181
182    /// Resets the circuit to the closed state and clears counts.
183    pub async fn reset(&self) {
184        let mut circuit = self.circuit.lock().await;
185        circuit.reset();
186    }
187
188    /// Returns the current state of the circuit.
189    pub async fn state(&self) -> CircuitState {
190        let circuit = self.circuit.lock().await;
191        circuit.state()
192    }
193}
194
195impl<S, Req, Res, Err> Service<Req> for CircuitBreaker<S, Res, Err>
196where
197    S: Service<Req, Response = Res, Error = Err> + Clone + Send + 'static,
198    S::Future: Send + 'static,
199    Res: Send + 'static,
200    Err: Send + 'static,
201    Req: Send + 'static,
202{
203    type Response = Res;
204    type Error = CircuitBreakerError<Err>;
205    type Future = BoxFuture<'static, Result<Res, Self::Error>>;
206
207    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208        self.inner
209            .poll_ready(cx)
210            .map_err(CircuitBreakerError::Inner)
211    }
212
213    fn call(&mut self, req: Req) -> Self::Future {
214        let config = Arc::clone(&self.config);
215        let circuit = Arc::clone(&self.circuit);
216        let mut inner = self.inner.clone();
217
218        Box::pin(async move {
219            #[cfg(feature = "tracing")]
220            {
221                let cb_name = config
222                    .name
223                    .as_deref()
224                    .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
225                debug!(
226                    breaker = cb_name,
227                    "Checking if call is permitted by circuit breaker"
228                );
229            }
230
231            #[cfg(feature = "tracing")]
232            let circuit_check_span = {
233                use tracing::{Level, span};
234                let state = {
235                    // To avoid holding the lock too long, just get the state for span field.
236                    let circuit = circuit.lock().await;
237                    circuit.state()
238                };
239                let cb_name = config
240                    .name
241                    .as_deref()
242                    .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
243                span!(Level::DEBUG, "circuit_check", breaker = cb_name, state = ?state)
244            };
245            #[cfg(feature = "tracing")]
246            let _enter = circuit_check_span.enter();
247
248            let permitted = {
249                let mut circuit = circuit.lock().await;
250                circuit.try_acquire(&config)
251            };
252
253            #[cfg(feature = "tracing")]
254            {
255                let cb_name = config
256                    .name
257                    .as_deref()
258                    .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
259                if permitted {
260                    tracing::trace!(breaker = cb_name, "circuit breaker permitted call");
261                } else {
262                    tracing::trace!(
263                        breaker = cb_name,
264                        "circuit breaker rejected call (circuit open)"
265                    );
266                }
267            }
268
269            if !permitted {
270                #[cfg(feature = "metrics")]
271                {
272                    let counter = counter!("circuitbreaker_calls_total", "outcome" => "rejected");
273                    counter.increment(1);
274                }
275                return Err(CircuitBreakerError::OpenCircuit);
276            }
277
278            let result = inner.call(req).await;
279
280            let mut circuit = circuit.lock().await;
281            if (config.failure_classifier)(&result) {
282                circuit.record_failure(&config);
283            } else {
284                circuit.record_success(&config);
285            }
286
287            result.map_err(CircuitBreakerError::Inner)
288        })
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Test fixture: 50% failure threshold over a 10-call window, at least 10 calls before the
    /// rate is evaluated, and any `Err` counted as a failure.
    fn dummy_config() -> CircuitBreakerConfig<(), ()> {
        CircuitBreakerConfig {
            failure_rate_threshold: 0.5,
            sliding_window_size: 10,
            wait_duration_in_open: Duration::from_secs(1),
            permitted_calls_in_half_open: 1,
            failure_classifier: Arc::new(|outcome| outcome.is_err()),
            minimum_number_of_calls: 10,
            #[cfg(feature = "tracing")]
            name: Some("test".into()),
        }
    }

    /// Builds a new circuit and feeds it `failures` failures followed by `successes` successes,
    /// all recorded against `dummy_config()`.
    fn circuit_after(failures: usize, successes: usize) -> Circuit {
        let config = dummy_config();
        let mut circuit = Circuit::new();

        for _ in 0..failures {
            circuit.record_failure(&config);
        }
        for _ in 0..successes {
            circuit.record_success(&config);
        }

        circuit
    }

    #[test]
    fn transitions_to_open_on_high_failure_rate() {
        // 6 failures out of 10 calls is above the 50% threshold.
        let circuit = circuit_after(6, 4);
        assert_eq!(circuit.state(), CircuitState::Open);
    }

    #[test]
    fn stays_closed_on_low_failure_rate() {
        // 2 failures out of 10 calls stays below the 50% threshold.
        let circuit = circuit_after(2, 8);
        assert_eq!(circuit.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn manual_override_controls_work() {
        let breaker = CircuitBreaker::new((), Arc::new(dummy_config()));

        breaker.force_open().await;
        let after_open = breaker.state().await;
        assert_eq!(after_open, CircuitState::Open);

        breaker.force_closed().await;
        let after_close = breaker.state().await;
        assert_eq!(after_close, CircuitState::Closed);
    }

    #[test]
    fn test_error_helpers() {
        // An open-circuit rejection carries no inner error.
        let open: CircuitBreakerError<&str> = CircuitBreakerError::OpenCircuit;
        assert!(open.is_circuit_open());
        assert_eq!(open.into_inner(), None);

        // An inner-service error is not an open circuit and yields the wrapped value.
        let wrapped = CircuitBreakerError::Inner("fail");
        assert!(!wrapped.is_circuit_open());
        assert_eq!(wrapped.into_inner(), Some("fail"));
    }
}