sfo_http/http_server/
middleware.rs

1//! Middleware types.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use std::future::Future;
7use std::pin::Pin;
8use http::StatusCode;
9use route_recognizer::nfa::State;
10use crate::errors::HttpResult;
11use crate::http_server::endpoint::DynEndpoint;
12use super::{Endpoint, Request, Response,};
13
14/// Middleware that wraps around the remaining middleware chain.
15#[async_trait]
16pub trait Middleware<Req: Request, Resp: Response>: Send + Sync + 'static {
17    /// Asynchronously handle the request, and return a response.
18    async fn handle(&self, request: Req, next: Next<'_, Req, Resp>) -> HttpResult<Resp>;
19
20    /// Set the middleware's name. By default it uses the type signature.
21    fn name(&self) -> &str {
22        std::any::type_name::<Self>()
23    }
24}
25
26#[async_trait]
27impl<Req, Resp, F> Middleware<Req, Resp> for F
28where
29    Req: Request,
30    Resp: Response,
31    F: Send
32        + Sync
33        + 'static
34        + for<'a> Fn(
35            Req,
36            Next<'a, Req, Resp>,
37        ) -> Pin<Box<dyn Future<Output = HttpResult<Resp>> + 'a + Send>>,
38{
39    async fn handle(&self, req: Req, next: Next<'_, Req, Resp>) -> HttpResult<Resp> {
40        (self)(req, next).await
41    }
42}
43
44/// The remainder of a middleware chain, including the endpoint.
45#[allow(missing_debug_implementations)]
46pub struct Next<'a, Req: Request, Resp: Response> {
47    pub(crate) endpoint: &'a DynEndpoint<Req, Resp>,
48    pub(crate) next_middleware: &'a [Arc<dyn Middleware< Req, Resp>>],
49}
50
51impl<Req: Request, Resp: Response> Next<'_, Req, Resp> {
52    /// Asynchronously execute the remaining middleware chain.
53    pub async fn run(mut self, req: Req) -> Resp {
54        if let Some((current, next)) = self.next_middleware.split_first() {
55            self.next_middleware = next;
56            match current.handle(req, self).await {
57                Ok(request) => request,
58                Err(err) => {
59                    log::error!("middleware handle err: {}", err);
60                    Resp::new(StatusCode::INTERNAL_SERVER_ERROR)
61                },
62            }
63        } else {
64            match self.endpoint.call(req).await {
65                Ok(request) => request,
66                Err(err) => {
67                    log::error!("endpoint call err: {}", err);
68                    Resp::new(StatusCode::INTERNAL_SERVER_ERROR)
69                },
70            }
71        }
72    }
73}