// rustbasic_core/middleware/mod.rs

1pub mod logging;
2pub mod security_headers;
3pub mod cors;
4pub mod csrf;
5
6use std::sync::Arc;
7use std::pin::Pin;
8use crate::requests::Request;
9use crate::router::{Response, ErasedHandler};
10
11pub type MiddlewareFn = Arc<
12    dyn Fn(Request, Next) -> Pin<Box<dyn std::future::Future<Output = Response> + Send>>
13        + Send
14        + Sync,
15>;
16
17pub struct Next {
18    pub(crate) chain: Arc<MiddlewareChain>,
19}
20
21impl Next {
22    pub async fn run(self, req: Request) -> Response {
23        self.chain.next(req).await
24    }
25}
26
27pub enum MiddlewareChain {
28    Next(MiddlewareFn, Arc<MiddlewareChain>),
29    End(Arc<dyn ErasedHandler>),
30}
31
32impl MiddlewareChain {
33    pub async fn next(self: Arc<Self>, req: Request) -> Response {
34        match &*self {
35            Self::Next(mw, next_chain) => {
36                let next = Next { chain: next_chain.clone() };
37                mw(req, next).await
38            }
39            Self::End(handler) => {
40                handler.call(req).await
41            }
42        }
43    }
44}
45
46pub fn from_fn<F, Fut>(mw: F) -> MiddlewareFn
47where
48    F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
49    Fut: std::future::Future<Output = Response> + Send + 'static,
50{
51    Arc::new(move |req, next| Box::pin(mw(req, next)))
52}