tower_fallthrough_filter/
lib.rs

1use std::{
2    marker::PhantomData,
3    task::{Context, Poll},
4};
5
6use ::futures::{future::Either, ready};
7use tower::{Layer, Service};
8
9#[cfg(test)]
10pub mod test_util;
11
12#[cfg(feature = "futures")]
13pub mod futures;
14
15#[cfg(feature = "async")]
16pub use async_feature::{AsyncFilter, AsyncFilterLayer, AsyncFilterService};
17
18#[cfg(feature = "async")]
19mod async_feature;
20
/// Decides, per request, whether the filtered service should handle it.
///
/// Implementors inspect a borrowed request of type `T` and return `true`
/// when the request should be routed to the filtered service. Returning
/// `false` lets the request fall through to the next (inner) service.
///
/// The `Clone` requirement lets a layer hand its own copy of the filter to
/// every service it builds.
///
/// # Example
/// ```rust
/// # use tower_fallthrough_filter::Filter;
///
/// #[derive(Debug, Clone)]
/// struct MyFilter;
///
/// impl<T> Filter<T> for MyFilter {
///     fn matches(&self, _: &T) -> bool {
///         true
///     }
/// }
///
/// let filter = MyFilter;
/// assert!(filter.matches(&()));
/// ```
pub trait Filter<T>
where
    Self: Clone,
{
    /// Returns `true` if the filtered service should handle `item`.
    ///
    /// A `false` result makes the request fall through to the next service.
    fn matches(&self, item: &T) -> bool;
}
46
47/// A Tower layer that executes the provided service only
48/// if the given filter returns true.
49/// Otherwise it falls through to the inner server.
50///
51/// # Example
52/// ```rust
53/// use tower_fallthrough_filter::{Filter, FilterLayer};
54/// use tower::{Service, Layer};
55///
56/// #[derive(Debug, Clone)]
57/// struct MyFilter;
58///
59/// impl Filter<bool> for MyFilter {
60///     fn matches(&self, data: &bool) -> bool {
61///         *data
62///     }
63/// }
64///
65/// #[derive(Debug, Clone)]
66/// struct StringService(String);
67///
68/// impl Service<bool> for StringService {
69///     type Response = String;
70///     type Error = std::convert::Infallible;
71///     type Future = std::future::Ready::<Result<Self::Response, Self::Error>>;
72///
73///     fn poll_ready(
74///         &mut self,
75///         _: &mut std::task::Context<'_>,
76///     ) -> std::task::Poll<Result<(), Self::Error>> {
77///         std::task::Poll::Ready(Ok(()))
78///     }
79///
80///     fn call(&mut self, req: bool) -> Self::Future {
81///         std::future::ready(Ok(self.0.clone()))
82///     }
83/// }
84///
85/// #[tokio::main]
86/// async fn main() {
87///     let service_a = StringService("A".to_string());
88///     let service_b = StringService("B".to_string());
89///     let filter = MyFilter;
90///
91///     let mut middleware = FilterLayer::new(filter, service_a).layer(service_b);
92///
93///     assert_eq!(middleware.call(true).await, Ok("A".to_string()));
94///     assert_eq!(middleware.call(false).await, Ok("B".to_string()));
95/// }
96///
97#[derive(Debug)]
98pub struct FilterLayer<F, S, T, R, E>
99where
100    F: Filter<T>,
101    S: Service<T, Response = R, Error = E>,
102{
103    filter: F,
104    service: S,
105
106    _marker: PhantomData<(T, R, E)>,
107}
108
109// NOTE: This is required to make the `FilterLayer` clonable
110//       as the `PhantomData` might be not clonable.
111impl<F, S, R, E, T> Clone for FilterLayer<F, S, T, R, E>
112where
113    F: Filter<T> + Clone,
114    S: Service<T, Response = R, Error = E> + Clone,
115{
116    fn clone(&self) -> Self {
117        Self {
118            filter: self.filter.clone(),
119            service: self.service.clone(),
120
121            _marker: PhantomData,
122        }
123    }
124}
125
126impl<F: Filter<T>, S: Service<T>, T> FilterLayer<F, S, T, S::Response, S::Error> {
127    /// Creates a new FilterLayer given a `Service` and a `Filter`.
128    ///
129    /// NOTE: The Service and the Filter have to operate on the same
130    /// type `T`.
131    pub fn new(filter: F, service: S) -> Self {
132        Self {
133            filter,
134            service,
135
136            _marker: PhantomData,
137        }
138    }
139}
140
141impl<F, S, I, T, R, E> Layer<I> for FilterLayer<F, S, T, R, E>
142where
143    F: Filter<T> + Clone,
144    S: Service<T, Response = R, Error = E> + Clone,
145    I: Service<T, Response = R, Error = E> + Clone,
146{
147    type Service = FilterService<F, S, I, T, R, E>;
148
149    fn layer(&self, inner_service: I) -> Self::Service {
150        let filter = self.filter.clone();
151        let filtered_service = self.service.clone();
152
153        FilterService {
154            filter,
155            service: filtered_service,
156            inner: inner_service,
157
158            _marker: PhantomData,
159        }
160    }
161}
162
163#[derive(Debug)]
164pub struct FilterService<F, S, I, T, R, E>
165where
166    F: Filter<T>,
167    S: Service<T, Response = R, Error = E>,
168    I: Service<T, Response = R, Error = E>,
169{
170    filter: F,
171    service: S,
172    inner: I,
173
174    _marker: PhantomData<(T, R, E)>,
175}
176
177// NOTE: This is required to make the `FilterService` clonable
178//       as the `PhantomData` might be not clonable.
179impl<F, S, I, T, R, E> Clone for FilterService<F, S, I, T, R, E>
180where
181    F: Filter<T>,
182    S: Service<T, Response = R, Error = E> + Clone,
183    I: Service<T, Response = R, Error = E> + Clone,
184{
185    fn clone(&self) -> Self {
186        Self {
187            filter: self.filter.clone(),
188            service: self.service.clone(),
189            inner: self.inner.clone(),
190
191            _marker: PhantomData,
192        }
193    }
194}
195
196impl<F, S, I, T, R, E> Service<T> for FilterService<F, S, I, T, R, E>
197where
198    F: Filter<T>,
199    S: Service<T, Response = R, Error = E>,
200    S::Future: Send + 'static,
201    I: Service<T, Response = R, Error = E>,
202    I::Future: Send + 'static,
203{
204    type Response = S::Response;
205    type Error = S::Error;
206    type Future = Either<S::Future, I::Future>;
207
208    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209        ready!(self.service.poll_ready(cx))?;
210        // NOTE: It is probably best to poll the `inner_service` here as well
211        //       as otherwise it might be called when it isn't ready yet.
212        ready!(self.inner.poll_ready(cx))?;
213
214        Poll::Ready(Ok(()))
215    }
216
217    fn call(&mut self, req: T) -> Self::Future {
218        if self.filter.matches(&req) {
219            Either::Left(self.service.call(req))
220        } else {
221            Either::Right(self.inner.call(req))
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::*;

    // Each test puts `TestService("a")` behind a `TestFilter` and uses
    // `TestService("b")` as the fall-through inner service. The filter's
    // flag decides which one answers.

    #[tokio::test]
    async fn should_allow() {
        let mut middleware =
            FilterLayer::new(TestFilter(true), TestService("a")).layer(TestService("b"));

        assert_eq!(middleware.call(()).await, Ok("a"));
    }

    #[tokio::test]
    async fn should_fall_through() {
        let mut middleware =
            FilterLayer::new(TestFilter(false), TestService("a")).layer(TestService("b"));

        assert_eq!(middleware.call(()).await, Ok("b"));
    }
}