// salvo_core/service.rs

1use std::fmt::{self, Debug, Formatter};
2use std::pin::Pin;
3use std::sync::Arc;
4
5use headers::HeaderValue;
6use http::header::{ALT_SVC, CONTENT_TYPE};
7use http::uri::Scheme;
8use hyper::service::Service as HyperService;
9use hyper::{Method, Request as HyperRequest, Response as HyperResponse};
10
11use crate::catcher::{Catcher, write_error_default};
12use crate::conn::SocketAddr;
13use crate::fuse::ArcFusewire;
14use crate::handler::{Handler, WhenHoop};
15use crate::http::body::{ReqBody, ResBody};
16use crate::http::{Mime, Request, Response, StatusCode};
17use crate::routing::{FlowCtrl, PathState, Router};
18use crate::{Depot, async_trait};
19
20/// Service http request.
21#[non_exhaustive]
22pub struct Service {
23    /// The router of this service.
24    pub router: Arc<Router>,
25    /// The catcher of this service.
26    pub catcher: Option<Arc<Catcher>>,
27    /// These hoops will always be called when request received.
28    pub hoops: Vec<Arc<dyn Handler>>,
29    /// The allowed media types of this service.
30    pub allowed_media_types: Arc<Vec<Mime>>,
31}
32
33impl Debug for Service {
34    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35        f.debug_struct("Service")
36            .field("router", &self.router)
37            .field("catcher", &self.catcher)
38            .field("hoops", &self.hoops.len())
39            .field("allowed_media_types", &self.allowed_media_types.len())
40            .finish()
41    }
42}
43
44impl Service {
45    /// Create a new Service with a [`Router`].
46    #[inline]
47    pub fn new<T>(router: T) -> Self
48    where
49        T: Into<Arc<Router>>,
50    {
51        Self {
52            router: router.into(),
53            catcher: None,
54            hoops: vec![],
55            allowed_media_types: Arc::new(vec![]),
56        }
57    }
58
59    /// Get router in this `Service`.
60    #[inline]
61    #[must_use]
62    pub fn router(&self) -> Arc<Router> {
63        self.router.clone()
64    }
65
66    /// When the response code is 400-600 and the body is empty, capture and set the error page content.
67    /// If catchers is not set, the default error page will be used.
68    ///
69    /// # Example
70    ///
71    /// ```
72    /// use salvo_core::prelude::*;
73    /// use salvo_core::catcher::Catcher;
74    ///
75    /// #[handler]
76    /// async fn handle404(&self, _req: &Request, _depot: &Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
77    ///     if let Some(StatusCode::NOT_FOUND) = res.status_code {
78    ///         res.render("Custom 404 Error Page");
79    ///         ctrl.skip_rest();
80    ///     }
81    /// }
82    ///
83    /// #[tokio::main]
84    /// async fn main() {
85    ///     Service::new(Router::new()).catcher(Catcher::default().hoop(handle404));
86    /// }
87    /// ```
88    #[inline]
89    #[must_use]
90    pub fn catcher(mut self, catcher: impl Into<Arc<Catcher>>) -> Self {
91        self.catcher = Some(catcher.into());
92        self
93    }
94
95    /// Add a handler as middleware, it will run the handler when request received.
96    #[inline]
97    #[must_use]
98    pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
99        self.hoops.push(Arc::new(hoop));
100        self
101    }
102
103    /// Add a handler as middleware, it will run the handler when request received.
104    ///
105    /// This middleware is only effective when the filter returns true..
106    #[inline]
107    #[must_use]
108    pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
109    where
110        H: Handler,
111        F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
112    {
113        self.hoops.push(Arc::new(WhenHoop {
114            inner: hoop,
115            filter,
116        }));
117        self
118    }
119
120    /// Sets allowed media types list and returns `Self` for write code chained.
121    ///
122    /// # Example
123    ///
124    /// ```
125    /// # use salvo_core::prelude::*;
126    ///
127    /// # #[tokio::main]
128    /// # async fn main() {
129    /// let service = Service::new(Router::new()).allowed_media_types(vec![mime::TEXT_PLAIN]);
130    /// # }
131    /// ```
132    #[inline]
133    #[must_use]
134    pub fn allowed_media_types<T>(mut self, allowed_media_types: T) -> Self
135    where
136        T: Into<Arc<Vec<Mime>>>,
137    {
138        self.allowed_media_types = allowed_media_types.into();
139        self
140    }
141
142    #[doc(hidden)]
143    #[inline]
144    #[must_use]
145    pub fn hyper_handler(
146        &self,
147        local_addr: SocketAddr,
148        remote_addr: SocketAddr,
149        http_scheme: Scheme,
150        fusewire: Option<ArcFusewire>,
151        alt_svc_h3: Option<HeaderValue>,
152    ) -> HyperHandler {
153        HyperHandler {
154            local_addr,
155            remote_addr,
156            http_scheme,
157            router: self.router.clone(),
158            catcher: self.catcher.clone(),
159            hoops: self.hoops.clone(),
160            allowed_media_types: self.allowed_media_types.clone(),
161            fusewire,
162            alt_svc_h3,
163        }
164    }
165    /// Handle new request, this function only used for test.
166    #[cfg(feature = "test")]
167    #[inline]
168    pub async fn handle(&self, request: impl Into<Request> + Send) -> Response {
169        let request = request.into();
170        self.hyper_handler(
171            request.local_addr.clone(),
172            request.remote_addr.clone(),
173            request.scheme.clone(),
174            None,
175            None,
176        )
177        .handle(request)
178        .await
179    }
180}
181
182impl<T> From<T> for Service
183where
184    T: Into<Arc<Router>>,
185{
186    #[inline]
187    fn from(router: T) -> Self {
188        Self::new(router)
189    }
190}
191
192struct DefaultStatusOK;
193#[async_trait]
194impl Handler for DefaultStatusOK {
195    async fn handle(
196        &self,
197        req: &mut Request,
198        depot: &mut Depot,
199        res: &mut Response,
200        ctrl: &mut FlowCtrl,
201    ) {
202        ctrl.call_next(req, depot, res).await;
203        if res.status_code.is_none() {
204            res.status_code = Some(StatusCode::OK);
205        }
206    }
207}
208
209#[doc(hidden)]
210#[derive(Clone)]
211pub struct HyperHandler {
212    pub(crate) local_addr: SocketAddr,
213    pub(crate) remote_addr: SocketAddr,
214    pub(crate) http_scheme: Scheme,
215    pub(crate) router: Arc<Router>,
216    pub(crate) catcher: Option<Arc<Catcher>>,
217    pub(crate) hoops: Vec<Arc<dyn Handler>>,
218    pub(crate) allowed_media_types: Arc<Vec<Mime>>,
219    pub(crate) fusewire: Option<ArcFusewire>,
220    pub(crate) alt_svc_h3: Option<HeaderValue>,
221}
222impl Debug for HyperHandler {
223    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
224        f.debug_struct("HyperHandler")
225            .field("local_addr", &self.local_addr)
226            .field("remote_addr", &self.remote_addr)
227            .field("http_scheme", &self.http_scheme)
228            .field("router", &self.router)
229            .field("catcher", &self.catcher)
230            .field("allowed_media_types", &self.allowed_media_types)
231            .field("alt_svc_h3", &self.alt_svc_h3)
232            .finish()
233    }
234}
235impl HyperHandler {
236    /// Handle [`Request`] and returns [`Response`].
237    pub fn handle(&self, mut req: Request) -> impl Future<Output = Response> + 'static {
238        let catcher = self.catcher.clone();
239        let allowed_media_types = self.allowed_media_types.clone();
240        req.local_addr = self.local_addr.clone();
241        req.remote_addr = self.remote_addr.clone();
242        #[cfg(not(feature = "cookie"))]
243        let mut res = Response::new();
244        #[cfg(feature = "cookie")]
245        let mut res = Response::with_cookies(req.cookies.clone());
246        if let Some(alt_svc_h3) = &self.alt_svc_h3 {
247            if !res.headers().contains_key(ALT_SVC) {
248                res.headers_mut().insert(ALT_SVC, alt_svc_h3.clone());
249            }
250        }
251        let mut depot = Depot::new();
252        let mut path_state = PathState::new(req.uri().path());
253        let router = self.router.clone();
254
255        let hoops = self.hoops.clone();
256        async move {
257            if let Some(dm) = router.detect(&mut req, &mut path_state).await {
258                req.params = path_state.params;
259                #[cfg(feature = "matched-path")]
260                {
261                    req.matched_path = path_state.matched_parts.join("/");
262                }
263                // Set default status code before service hoops executed.
264                // We hope all hoops in service can get the correct status code.
265                let mut ctrl = FlowCtrl::new(
266                    [
267                        &hoops[..],
268                        &dm.hoops[..],
269                        &[Arc::new(DefaultStatusOK)],
270                        &[dm.goal],
271                    ]
272                    .concat(),
273                );
274                ctrl.call_next(&mut req, &mut depot, &mut res).await;
275                // Set it to default status code again if any hoop set status code to None.
276                if res.status_code.is_none() {
277                    res.status_code = Some(StatusCode::OK);
278                }
279            } else if !hoops.is_empty() {
280                req.params = path_state.params;
281                // Set default status code before service hoops executed.
282                // We hope all hoops in service can get the correct status code.
283                if path_state.once_ended {
284                    res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
285                } else {
286                    res.status_code = Some(StatusCode::NOT_FOUND);
287                }
288                let mut ctrl = FlowCtrl::new(hoops);
289                ctrl.call_next(&mut req, &mut depot, &mut res).await;
290                // Set it to default status code again if any hoop set status code to None.
291                if res.status_code.is_none() && path_state.once_ended {
292                    res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
293                }
294            } else if path_state.once_ended {
295                res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
296            }
297
298            let status_code = if let Some(status_code) = res.status_code {
299                status_code
300            } else {
301                res.status_code = Some(StatusCode::NOT_FOUND);
302                StatusCode::NOT_FOUND
303            };
304            if !allowed_media_types.is_empty() {
305                if let Some(ctype) = res
306                    .headers()
307                    .get(CONTENT_TYPE)
308                    .and_then(|c| c.to_str().ok())
309                    .and_then(|c| c.parse::<Mime>().ok())
310                {
311                    let mut is_allowed = false;
312                    for mime in &*allowed_media_types {
313                        if mime.type_() == ctype.type_() && mime.subtype() == ctype.subtype() {
314                            is_allowed = true;
315                            break;
316                        }
317                    }
318                    if !is_allowed {
319                        res.status_code(StatusCode::UNSUPPORTED_MEDIA_TYPE);
320                    }
321                }
322            }
323            let has_error = status_code.is_client_error() || status_code.is_server_error();
324            if res.body.is_none()
325                && !has_error
326                && !status_code.is_redirection()
327                && status_code != StatusCode::NO_CONTENT
328                && status_code != StatusCode::SWITCHING_PROTOCOLS
329                && [Method::GET, Method::POST, Method::PATCH, Method::PUT].contains(req.method())
330            {
331                // check for avoid warning when errors (404 etc.)
332                tracing::warn!(
333                    uri = ?req.uri(),
334                    method = req.method().as_str(),
335                    "http response content type header not set"
336                );
337            }
338            if Method::HEAD != *req.method()
339                && (res.body.is_none() || res.body.is_error())
340                && has_error
341            {
342                if let Some(catcher) = catcher {
343                    catcher.catch(&mut req, &mut depot, &mut res).await;
344                } else {
345                    write_error_default(&req, &mut res, None);
346                }
347            }
348            #[cfg(debug_assertions)]
349            if Method::HEAD == *req.method() && !res.body.is_none() {
350                tracing::warn!(
351                    "request with head method should not have body: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/HEAD"
352                );
353            }
354            #[cfg(feature = "quinn")]
355            {
356                use bytes::Bytes;
357                use std::sync::Mutex;
358                if let Some(session) =
359                    req.extensions.remove::<Arc<
360                        crate::proto::WebTransportSession<salvo_http3::quinn::Connection, Bytes>,
361                    >>()
362                {
363                    res.extensions.insert(session);
364                }
365                if let Some(conn) = req.extensions.remove::<Arc<
366                    Mutex<salvo_http3::server::Connection<salvo_http3::quinn::Connection, Bytes>>,
367                >>() {
368                    res.extensions.insert(conn);
369                }
370                if let Some(stream) = req.extensions.remove::<Arc<
371                    salvo_http3::server::RequestStream<
372                        salvo_http3::quinn::BidiStream<Bytes>,
373                        Bytes,
374                    >,
375                >>() {
376                    res.extensions.insert(stream);
377                }
378            }
379            res
380        }
381    }
382}
383
384impl<B> HyperService<HyperRequest<B>> for HyperHandler
385where
386    B: Into<ReqBody>,
387{
388    type Response = HyperResponse<ResBody>;
389    type Error = hyper::Error;
390    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
391
392    #[inline]
393    fn call(
394        &self,
395        #[cfg(not(feature = "fix-http1-request-uri"))] req: HyperRequest<B>,
396        #[cfg(feature = "fix-http1-request-uri")] mut req: HyperRequest<B>,
397    ) -> Self::Future {
398        let scheme = req
399            .uri()
400            .scheme()
401            .cloned()
402            .unwrap_or_else(|| self.http_scheme.clone());
403        // https://github.com/hyperium/hyper/issues/1310
404        #[cfg(feature = "fix-http1-request-uri")]
405        if req.uri().scheme().is_none() {
406            if let Some(host) = req
407                .headers()
408                .get(http::header::HOST)
409                .and_then(|host| host.to_str().ok())
410                .and_then(|host| host.parse::<http::uri::Authority>().ok())
411            {
412                let mut uri_parts = std::mem::take(req.uri_mut()).into_parts();
413                uri_parts.scheme = Some(scheme.clone());
414                uri_parts.authority = Some(host);
415                if let Ok(uri) = http::uri::Uri::from_parts(uri_parts) {
416                    *req.uri_mut() = uri;
417                }
418            }
419        }
420        let mut request = Request::from_hyper(req, scheme);
421        request.body.set_fusewire(self.fusewire.clone());
422        let response = self.handle(request);
423        Box::pin(async move { Ok(response.await.into_hyper()) })
424    }
425}
426
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    const BASE: &str = "http://127.0.0.1:5801";

    /// Hoops render their marker, and the one whose number equals query `b`
    /// stops the chain; later hoops and the goal must not run.
    #[tokio::test]
    async fn test_service() {
        #[handler]
        async fn before1(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before1"));
            let stop_here = req.query::<String>("b").unwrap_or_default() == "1";
            if stop_here {
                ctrl.skip_rest();
                return;
            }
            ctrl.call_next(req, depot, res).await;
        }
        #[handler]
        async fn before2(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before2"));
            let stop_here = req.query::<String>("b").unwrap_or_default() == "2";
            if stop_here {
                ctrl.skip_rest();
                return;
            }
            ctrl.call_next(req, depot, res).await;
        }
        #[handler]
        async fn before3(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before3"));
            let stop_here = req.query::<String>("b").unwrap_or_default() == "3";
            if stop_here {
                ctrl.skip_rest();
                return;
            }
            ctrl.call_next(req, depot, res).await;
        }
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("hello")
        }

        let hello_route = Router::with_hoop(before3).path("hello").goal(hello);
        let level2 = Router::with_hoop(before2).path("level2").push(hello_route);
        let router = Router::with_path("level1").hoop(before1).push(level2);
        let service = Service::new(router);

        async fn access(service: &Service, b: &str) -> String {
            TestClient::get(format!("{BASE}/level1/level2/hello?b={b}"))
                .send(service)
                .await
                .take_string()
                .await
                .unwrap()
        }

        let expectations = [
            ("", "before1before2before3hello"),
            ("1", "before1"),
            ("2", "before1before2"),
            ("3", "before1before2before3"),
        ];
        for (b, expected) in expectations {
            let body = access(&service, b).await;
            assert_eq!(body, expected, "query b={b}");
        }
    }

    /// Unmatched methods on an existing path yield 405; unknown paths yield 404.
    #[tokio::test]
    async fn test_service_405_or_404_error() {
        #[handler]
        async fn login() -> &'static str {
            "login"
        }
        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let login_routes = Router::with_path("login")
            .post(login)
            .push(Router::with_path("user").get(login));
        let router = Router::new()
            .push(Router::with_path("hello").goal(hello))
            .push(login_routes);
        let service = Service::new(router);

        let cases = [
            (TestClient::get(format!("{BASE}/hello")), StatusCode::OK),
            (TestClient::put(format!("{BASE}/hello")), StatusCode::OK),
            (TestClient::post(format!("{BASE}/login")), StatusCode::OK),
            (
                TestClient::get(format!("{BASE}/login")),
                StatusCode::METHOD_NOT_ALLOWED,
            ),
            (
                TestClient::get(format!("{BASE}/login2")),
                StatusCode::NOT_FOUND,
            ),
            (TestClient::get(format!("{BASE}/login/user")), StatusCode::OK),
            (
                TestClient::post(format!("{BASE}/login/user")),
                StatusCode::METHOD_NOT_ALLOWED,
            ),
            (
                TestClient::post(format!("{BASE}/login/user1")),
                StatusCode::NOT_FOUND,
            ),
        ];
        for (index, (request, expected)) in cases.into_iter().enumerate() {
            let res = request.send(&service).await;
            assert_eq!(res.status_code.unwrap(), expected, "case #{index}");
        }
    }
}