1use std::marker::PhantomData;
2
3use motore::{Service, layer::Layer};
4
5use crate::{
6 context::ServerContext,
7 request::Request,
8 response::Response,
9 server::{IntoResponse, handler::HandlerWithoutRequest},
10};
11
/// A [`Layer`](motore::layer::Layer) that wraps a service in a [`Filter`].
///
/// The stored `handler` runs before the inner service on every request:
/// `Ok(())` lets the request through, `Err(r)` answers immediately with
/// `r.into_response()`, and a failure to extract the handler's arguments
/// answers with that rejection's response.
pub struct FilterLayer<H, R, T> {
    handler: H,
    // `fn() -> (R, T)` keeps the marker covariant in `R` and `T` without making
    // the layer's `Send`/`Sync` depend on them (neither is ever stored).
    _marker: PhantomData<fn() -> (R, T)>,
}

// Implemented by hand: `#[derive(Clone)]` would also require `R: Clone` and
// `T: Clone`, although only `handler` actually needs to be cloned.
impl<H, R, T> Clone for FilterLayer<H, R, T>
where
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
20
21impl<H, R, T> FilterLayer<H, R, T> {
22 pub fn new(h: H) -> Self {
58 Self {
59 handler: h,
60 _marker: PhantomData,
61 }
62 }
63}
64
65impl<S, H, R, T> Layer<S> for FilterLayer<H, R, T>
66where
67 S: Send + Sync + 'static,
68 H: Clone + Send + Sync + 'static,
69 T: Sync,
70{
71 type Service = Filter<S, H, R, T>;
72
73 fn layer(self, inner: S) -> Self::Service {
74 Filter {
75 service: inner,
76 handler: self.handler,
77 _marker: PhantomData,
78 }
79 }
80}
81
/// Service produced by [`FilterLayer`]: runs `handler` against the request
/// parts and only calls the wrapped `service` when the handler returns `Ok(())`.
pub struct Filter<S, H, R, T> {
    service: S,
    handler: H,
    // `fn() -> (R, T)` keeps the marker covariant in `R` and `T` without making
    // the service's `Send`/`Sync` depend on them (neither is ever stored).
    _marker: PhantomData<fn() -> (R, T)>,
}

// Implemented by hand: `#[derive(Clone)]` would also require `R: Clone` and
// `T: Clone`, although only `service` and `handler` need to be cloned.
impl<S, H, R, T> Clone for Filter<S, H, R, T>
where
    S: Clone,
    H: Clone,
{
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            handler: self.handler.clone(),
            _marker: PhantomData,
        }
    }
}
91
92impl<S, B, H, R, T> Service<ServerContext, Request<B>> for Filter<S, H, R, T>
93where
94 S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
95 S::Response: IntoResponse,
96 S::Error: IntoResponse,
97 B: Send,
98 H: HandlerWithoutRequest<T, Result<(), R>> + Clone + Send + Sync + 'static,
99 R: IntoResponse + Send + Sync,
100 T: Sync,
101{
102 type Response = Response;
103 type Error = S::Error;
104
105 async fn call(
106 &self,
107 cx: &mut ServerContext,
108 req: Request<B>,
109 ) -> Result<Self::Response, Self::Error> {
110 let (mut parts, body) = req.into_parts();
111 let res = self.handler.clone().handle(cx, &mut parts).await;
112 let req = Request::from_parts(parts, body);
113 match res {
114 Ok(Ok(())) => self
116 .service
117 .call(cx, req)
118 .await
119 .map(IntoResponse::into_response),
120 Ok(Err(res)) => Ok(res.into_response()),
122 Err(rej) => {
124 tracing::warn!("[Volo-HTTP] FilterLayer: something wrong while extracting");
125 Ok(rej.into_response())
126 }
127 }
128 }
129}
130
#[cfg(test)]
mod filter_tests {
    use http::{Method, StatusCode};
    use motore::{Service, layer::Layer};

    use crate::{
        body::BodyConversion,
        server::{
            route::{Route, any},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// A filter rejecting POST must answer 405 itself while letting other
    /// methods reach the wrapped route.
    #[tokio::test]
    async fn test_filter_layer() {
        use crate::server::layer::FilterLayer;

        // Filter: let everything through except POST.
        async fn reject_post(method: Method) -> Result<(), StatusCode> {
            if method != Method::POST {
                Ok(())
            } else {
                Err(StatusCode::METHOD_NOT_ALLOWED)
            }
        }

        async fn hello() -> &'static str {
            "Hello, World"
        }

        let route: Route<&str> = Route::new(any(hello));
        let service = FilterLayer::new(reject_post).layer(route);
        let mut cx = empty_cx();

        // GET passes the filter and reaches the route.
        let get_resp = service
            .call(&mut cx, simple_req(Method::GET, "/", ""))
            .await
            .unwrap();
        let text = get_resp.into_body().into_string().await.unwrap();
        assert_eq!(text, "Hello, World");

        // POST is stopped by the filter before the route runs.
        let post_resp = service
            .call(&mut cx, simple_req(Method::POST, "/", ""))
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }
}