tower_fallthrough_filter/
lib.rs1use std::{
2 marker::PhantomData,
3 task::{Context, Poll},
4};
5
6use ::futures::{future::Either, ready};
7use tower::{Layer, Service};
8
9#[cfg(test)]
10pub mod test_util;
11
12#[cfg(feature = "futures")]
13pub mod futures;
14
15#[cfg(feature = "async")]
16pub use async_feature::{AsyncFilter, AsyncFilterLayer, AsyncFilterService};
17
18#[cfg(feature = "async")]
19mod async_feature;
20
/// Predicate deciding where a request of type `T` is routed.
///
/// When [`Filter::matches`] returns `true` the request goes to the filtered
/// service; otherwise it falls through to the inner (wrapped) service.
///
/// Implementors must be `Clone`: the layer hands a fresh copy of the filter to
/// every service it builds.
pub trait Filter<T>
where
    Self: Clone,
{
    /// Returns `true` if `item` should be handled by the filtered service.
    fn matches(&self, item: &T) -> bool;
}
46
/// A `Layer` that routes requests matching `filter` to `service` and lets every
/// other request fall through to the service it wraps.
///
/// `T` is the request type and `R`/`E` are the response/error types shared by
/// both services; they appear only in `PhantomData`.
///
/// Trait bounds are declared on the `impl` blocks that need them, not on the
/// struct itself. This keeps the type nameable without repeating them, and lets
/// the derived `Debug` apply whenever the fields are `Debug`.
#[derive(Debug)]
pub struct FilterLayer<F, S, T, R, E> {
    filter: F,
    service: S,

    _marker: PhantomData<(T, R, E)>,
}
108
109impl<F, S, R, E, T> Clone for FilterLayer<F, S, T, R, E>
112where
113 F: Filter<T> + Clone,
114 S: Service<T, Response = R, Error = E> + Clone,
115{
116 fn clone(&self) -> Self {
117 Self {
118 filter: self.filter.clone(),
119 service: self.service.clone(),
120
121 _marker: PhantomData,
122 }
123 }
124}
125
126impl<F: Filter<T>, S: Service<T>, T> FilterLayer<F, S, T, S::Response, S::Error> {
127 pub fn new(filter: F, service: S) -> Self {
132 Self {
133 filter,
134 service,
135
136 _marker: PhantomData,
137 }
138 }
139}
140
141impl<F, S, I, T, R, E> Layer<I> for FilterLayer<F, S, T, R, E>
142where
143 F: Filter<T> + Clone,
144 S: Service<T, Response = R, Error = E> + Clone,
145 I: Service<T, Response = R, Error = E> + Clone,
146{
147 type Service = FilterService<F, S, I, T, R, E>;
148
149 fn layer(&self, inner_service: I) -> Self::Service {
150 let filter = self.filter.clone();
151 let filtered_service = self.service.clone();
152
153 FilterService {
154 filter,
155 service: filtered_service,
156 inner: inner_service,
157
158 _marker: PhantomData,
159 }
160 }
161}
162
/// The service produced by `FilterLayer`.
///
/// Requests for which `filter` matches are sent to `service`. Everything else
/// falls through to `inner`.
///
/// `T` is the request type and `R`/`E` are the response/error types shared by
/// both services; they appear only in `PhantomData`. Trait bounds live on the
/// `impl` blocks rather than on the struct, so the type can be named without
/// repeating them.
#[derive(Debug)]
pub struct FilterService<F, S, I, T, R, E> {
    filter: F,
    service: S,
    inner: I,

    _marker: PhantomData<(T, R, E)>,
}
176
177impl<F, S, I, T, R, E> Clone for FilterService<F, S, I, T, R, E>
180where
181 F: Filter<T>,
182 S: Service<T, Response = R, Error = E> + Clone,
183 I: Service<T, Response = R, Error = E> + Clone,
184{
185 fn clone(&self) -> Self {
186 Self {
187 filter: self.filter.clone(),
188 service: self.service.clone(),
189 inner: self.inner.clone(),
190
191 _marker: PhantomData,
192 }
193 }
194}
195
196impl<F, S, I, T, R, E> Service<T> for FilterService<F, S, I, T, R, E>
197where
198 F: Filter<T>,
199 S: Service<T, Response = R, Error = E>,
200 S::Future: Send + 'static,
201 I: Service<T, Response = R, Error = E>,
202 I::Future: Send + 'static,
203{
204 type Response = S::Response;
205 type Error = S::Error;
206 type Future = Either<S::Future, I::Future>;
207
208 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209 ready!(self.service.poll_ready(cx))?;
210 ready!(self.inner.poll_ready(cx))?;
213
214 Poll::Ready(Ok(()))
215 }
216
217 fn call(&mut self, req: T) -> Self::Future {
218 if self.filter.matches(&req) {
219 Either::Left(self.service.call(req))
220 } else {
221 Either::Right(self.inner.call(req))
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::*;

    /// Drives `svc` to readiness, then sends `()` through it.
    ///
    /// Tower's `Service` contract requires `poll_ready` to return
    /// `Ready(Ok(()))` before `call` is invoked. The tests honour that instead
    /// of calling `call` directly.
    async fn ready_call<Svc>(svc: &mut Svc) -> Result<Svc::Response, Svc::Error>
    where
        Svc: Service<()>,
        Svc::Error: std::fmt::Debug,
    {
        // `::futures` names the external crate, not this crate's `futures` module.
        ::futures::future::poll_fn(|cx| svc.poll_ready(cx))
            .await
            .expect("service should become ready");
        svc.call(()).await
    }

    #[tokio::test]
    async fn should_allow() {
        let service_a = TestService("a");
        let service_b = TestService("b");

        let filter = TestFilter(true);
        let filter_layer = FilterLayer::new(filter, service_a);

        let mut middleware = filter_layer.layer(service_b);

        assert_eq!(ready_call(&mut middleware).await, Ok("a"));
    }

    #[tokio::test]
    async fn should_fall_through() {
        let service_a = TestService("a");
        let service_b = TestService("b");

        let filter = TestFilter(false);
        let filter_layer = FilterLayer::new(filter, service_a);

        let mut middleware = filter_layer.layer(service_b);

        assert_eq!(ready_call(&mut middleware).await, Ok("b"));
    }
}