volo_http/server/layer/
filter.rs

1use std::marker::PhantomData;
2
3use motore::{Service, layer::Layer};
4
5use crate::{
6    context::ServerContext,
7    request::Request,
8    response::Response,
9    server::{IntoResponse, handler::HandlerWithoutRequest},
10};
11
/// [`Layer`] for filtering requests
///
/// Holds the filter handler `H`. The `R` (rejection) and `T` (handler argument) type parameters
/// are carried only through [`PhantomData`]; no value of either type is stored.
///
/// See [`FilterLayer::new`] for more details.
pub struct FilterLayer<H, R, T> {
    handler: H,
    _marker: PhantomData<(R, T)>,
}

// `Clone` is written by hand instead of derived: `#[derive(Clone)]` would add `R: Clone` and
// `T: Clone` bounds even though `R` and `T` only appear inside `PhantomData`, which would make
// the layer impossible to clone whenever the rejection or argument types are not `Clone`.
impl<H, R, T> Clone for FilterLayer<H, R, T>
where
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
20
21impl<H, R, T> FilterLayer<H, R, T> {
22    /// Create a new [`FilterLayer`]
23    ///
24    /// The `handler` is an async function with some params that implement
25    /// [`FromContext`](crate::server::extract::FromContext), and returns
26    /// `Result<(), impl IntoResponse>`.
27    ///
28    /// If the handler returns `Ok(())`, the request will proceed. However, if the handler returns
29    /// `Err` with an object that implements [`IntoResponse`], the request will be rejected with
30    /// the returned object as the response.
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// use http::{method::Method, status::StatusCode};
36    /// use volo_http::server::{
37    ///     layer::FilterLayer,
38    ///     route::{Router, get},
39    /// };
40    ///
41    /// async fn reject_post(method: Method) -> Result<(), StatusCode> {
42    ///     if method == Method::POST {
43    ///         Err(StatusCode::METHOD_NOT_ALLOWED)
44    ///     } else {
45    ///         Ok(())
46    ///     }
47    /// }
48    ///
49    /// async fn handler() -> &'static str {
50    ///     "Hello, World"
51    /// }
52    ///
53    /// let router: Router = Router::new()
54    ///     .route("/", get(handler))
55    ///     .layer(FilterLayer::new(reject_post));
56    /// ```
57    pub fn new(h: H) -> Self {
58        Self {
59            handler: h,
60            _marker: PhantomData,
61        }
62    }
63}
64
65impl<S, H, R, T> Layer<S> for FilterLayer<H, R, T>
66where
67    S: Send + Sync + 'static,
68    H: Clone + Send + Sync + 'static,
69    T: Sync,
70{
71    type Service = Filter<S, H, R, T>;
72
73    fn layer(self, inner: S) -> Self::Service {
74        Filter {
75            service: inner,
76            handler: self.handler,
77            _marker: PhantomData,
78        }
79    }
80}
81
/// [`FilterLayer`] generated [`Service`]
///
/// Stores the wrapped `service` together with the filter `handler`. The `R` (rejection) and `T`
/// (handler argument) type parameters are carried only through [`PhantomData`].
///
/// See [`FilterLayer`] for more details.
pub struct Filter<S, H, R, T> {
    service: S,
    handler: H,
    _marker: PhantomData<(R, T)>,
}

// `Clone` is written by hand instead of derived: `#[derive(Clone)]` would add `R: Clone` and
// `T: Clone` bounds even though `R` and `T` only appear inside `PhantomData`, which would make
// the service impossible to clone whenever the rejection or argument types are not `Clone`.
impl<S, H, R, T> Clone for Filter<S, H, R, T>
where
    S: Clone,
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
91
92impl<S, B, H, R, T> Service<ServerContext, Request<B>> for Filter<S, H, R, T>
93where
94    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
95    S::Response: IntoResponse,
96    S::Error: IntoResponse,
97    B: Send,
98    H: HandlerWithoutRequest<T, Result<(), R>> + Clone + Send + Sync + 'static,
99    R: IntoResponse + Send + Sync,
100    T: Sync,
101{
102    type Response = Response;
103    type Error = S::Error;
104
105    async fn call(
106        &self,
107        cx: &mut ServerContext,
108        req: Request<B>,
109    ) -> Result<Self::Response, Self::Error> {
110        let (mut parts, body) = req.into_parts();
111        let res = self.handler.clone().handle(cx, &mut parts).await;
112        let req = Request::from_parts(parts, body);
113        match res {
114            // do not filter it, call the service
115            Ok(Ok(())) => self
116                .service
117                .call(cx, req)
118                .await
119                .map(IntoResponse::into_response),
120            // filter it and return the specified response
121            Ok(Err(res)) => Ok(res.into_response()),
122            // something wrong while extracting
123            Err(rej) => {
124                tracing::warn!("[Volo-HTTP] FilterLayer: something wrong while extracting");
125                Ok(rej.into_response())
126            }
127        }
128    }
129}
130
#[cfg(test)]
mod filter_tests {
    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        body::BodyConversion,
        server::{
            layer::FilterLayer,
            route::{Route, any},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// A `GET` request must reach the inner handler, while a `POST` request must be answered by
    /// the filter itself with `405 Method Not Allowed`.
    #[tokio::test]
    async fn test_filter_layer() {
        // Filter under test: lets everything through except `POST`.
        async fn reject_post(method: Method) -> Result<(), StatusCode> {
            if method != Method::POST {
                Ok(())
            } else {
                Err(StatusCode::METHOD_NOT_ALLOWED)
            }
        }

        // Inner handler whose body proves the request was forwarded.
        async fn handler() -> &'static str {
            "Hello, World"
        }

        let route: Route<&str> = Route::new(any(handler));
        let filtered = FilterLayer::new(reject_post).layer(route);

        let mut cx = empty_cx();

        // Test case 1: not filter
        let get_req = simple_req(Method::GET, "/", "");
        let passed = filtered.call(&mut cx, get_req).await.unwrap();
        let body = passed.into_body().into_string().await.unwrap();
        assert_eq!(body, "Hello, World");

        // Test case 2: filter
        let post_req = simple_req(Method::POST, "/", "");
        let rejected = filtered.call(&mut cx, post_req).await.unwrap();
        assert_eq!(rejected.status(), StatusCode::METHOD_NOT_ALLOWED);
    }
}