tower_fault/
error.rs

1//! # Error injection for `tower`
2//!
3//! Layer that injects errors randomly into a service. When an error is injected,
4//! the underlying service is not called.
5//!
6//! ## Usage
7//!
8//! ```rust
9//! use tower_fault::error::ErrorLayer;
10//! use tower::{service_fn, ServiceBuilder};
11//! # struct MyRequest { value: u64 };
12//! # async fn my_service(_req: MyRequest) -> Result<(), String> {
13//! #     Ok(())
14//! # }
15//!
16//! // Initialize an ErrorLayer with a 10% probability of returning
17//! // an error.
18//! let error_layer = ErrorLayer::new(0.1, |_: &MyRequest| String::from("error"));
19//!
20//! let service = ServiceBuilder::new()
21//!     .layer(error_layer)
22//!     .service(service_fn(my_service));
23//! ```
24//!
25//! ### Decider
26//!
//! The __decider__ is used to determine whether an error should be injected
//! for a request or not. This can be a boolean, float, Bernoulli distribution,
//! a closure, or a custom implementation of the [`Decider`] trait.
30//!
31//! For more information, see the [`decider`](crate::decider) module.
32//!
33//! ```rust
34//! use tower_fault::error::ErrorLayer;
35//! # struct MyRequest { value: u64 };
36//!
37//! // Never inject an error.
38//! ErrorLayer::new(false, |_: &MyRequest| String::from("error"));
39//! // Always inject an error.
40//! ErrorLayer::new(true, |_: &MyRequest| String::from("error"));
41//!
42//! // Inject an error 30% of the time.
43//! ErrorLayer::new(0.3, |_: &MyRequest| String::from("error"));
44//!
45//! // Inject an error based on the request content.
46//! ErrorLayer::new(|req: &MyRequest| req.value % 2 == 0, |_: &MyRequest| String::from("error"));
47//! ```
48//!
49//! ### Generator
50//!
51//! The __generator__ is a function that returns an error based on the
52//! request.
53//!
54//! ```rust
55//! use tower_fault::error::ErrorLayer;
56//! # struct MyRequest { value: u64 };
57//!
58//! // Customize the error based on the request payload
59//! ErrorLayer::new(false, |req: &MyRequest| format!("value: {}", req.value));
60//! ```
61//!
62
63use crate::decider::Decider;
64use std::{
65    future::Future,
66    marker::PhantomData,
67    pin::Pin,
68    task::{Context, Poll},
69};
70use tower::{Layer, Service};
71
/// Layer that injects errors into the service it wraps.
///
/// For each request, the decider `D` determines whether an error is
/// injected; when it is, the generator `G` builds that error from the
/// request and the wrapped service is not called.
///
/// Create one with [`ErrorLayer::new`], or start from
/// [`ErrorLayer::builder`] and fill in the parts with
/// [`ErrorLayer::with_decider`] and [`ErrorLayer::with_generator`].
#[derive(Clone, Debug)]
pub struct ErrorLayer<'a, D, G> {
    /// Decides, per request, whether an error is injected.
    decider: D,
    /// Builds the injected error from the request.
    generator: G,
    /// Only carries the `'a` lifetime that bounds the boxed futures returned
    /// by the services this layer produces.
    _phantom: PhantomData<&'a ()>,
}
82
83impl<'a> ErrorLayer<'a, (), ()> {
84    /// Create a new `ErrorLayer` builder.
85    pub fn builder() -> Self {
86        Self {
87            decider: (),
88            generator: (),
89            _phantom: PhantomData,
90        }
91    }
92}
93
94impl<'a, D, G> ErrorLayer<'a, D, G> {
95    /// Create a new `ErrorLayer` builder with the given probability
96    /// and error generator.
97    pub fn new(decider: D, generator: G) -> Self {
98        Self {
99            decider,
100            generator,
101            _phantom: PhantomData,
102        }
103    }
104
105    /// Set the given decider to be used to determine if an error
106    /// should be injected.
107    pub fn with_decider<ND>(self, decider: ND) -> ErrorLayer<'a, ND, G> {
108        ErrorLayer {
109            decider,
110            generator: self.generator,
111            _phantom: PhantomData,
112        }
113    }
114
115    /// Set the given error generator to generate errors.
116    pub fn with_generator<NG>(self, generator: NG) -> ErrorLayer<'a, D, NG> {
117        ErrorLayer {
118            decider: self.decider,
119            generator,
120            _phantom: PhantomData,
121        }
122    }
123}
124
125impl<'a, D, G, S> Layer<S> for ErrorLayer<'a, D, G>
126where
127    D: Clone,
128    G: Clone,
129{
130    type Service = ErrorService<'a, D, G, S>;
131
132    fn layer(&self, inner: S) -> Self::Service {
133        ErrorService {
134            inner,
135            decider: self.decider.clone(),
136            generator: self.generator.clone(),
137            _phantom: PhantomData,
138        }
139    }
140}
141
/// Service produced by [`ErrorLayer`] that may answer a request with an
/// injected error instead of calling the underlying service.
///
/// For each request the decider `D` is consulted; when it chooses to inject,
/// the generator `G` builds the error from the request and the inner service
/// `S` is not called for that request.
#[derive(Clone, Debug)]
pub struct ErrorService<'a, D, G, S> {
    /// The wrapped service, called when no error is injected.
    inner: S,
    /// Decides, per request, whether an error is injected.
    decider: D,
    /// Builds the injected error from the request.
    generator: G,
    /// Only carries the `'a` lifetime that bounds the boxed futures this
    /// service returns.
    _phantom: PhantomData<&'a ()>,
}
151
152impl<'a, D, G, S, R> Service<R> for ErrorService<'a, D, G, S>
153where
154    D: Decider<R> + Clone,
155    G: Fn(&R) -> S::Error + Clone,
156    S: Service<R> + Send,
157    S::Future: Send + 'a,
158    S::Error: Send + 'a,
159{
160    type Response = S::Response;
161    type Error = S::Error;
162    type Future = ErrorFuture<'a, R, S>;
163
164    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        self.inner.poll_ready(cx)
166    }
167
168    fn call(&mut self, request: R) -> Self::Future {
169        if self.decider.decide(&request) {
170            let error = (self.generator)(&request);
171            return Box::pin(async move { Err(error) });
172        }
173
174        Box::pin(self.inner.call(request))
175    }
176}
177
178type ErrorFuture<'a, R, S> = Pin<
179    Box<
180        dyn Future<Output = Result<<S as Service<R>>::Response, <S as Service<R>>::Error>>
181            + Send
182            + 'a,
183    >,
184>;
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    /// Number of requests each test sends through the layered service.
    const CALLS: usize = 1000;

    /// With a decider of `0.0`, every call must succeed with `"ok"`.
    #[tokio::test]
    async fn error_success() {
        let mut service =
            ErrorLayer::new(0.0, |_: &()| String::from("error")).layer(DummyService);

        for _attempt in 0..CALLS {
            let outcome = service.call(()).await;
            assert_eq!(outcome.unwrap(), String::from("ok"));
        }
    }

    /// With a decider of `1.0`, every call must fail with the generated
    /// `"error"`.
    #[tokio::test]
    async fn error_fail() {
        let mut service =
            ErrorLayer::new(1.0, |_: &()| String::from("error")).layer(DummyService);

        for _attempt in 0..CALLS {
            let outcome = service.call(()).await;
            assert_eq!(outcome.unwrap_err(), String::from("error"));
        }
    }
}