1use std::fmt::{self, Debug, Formatter};
2use std::pin::Pin;
3use std::sync::Arc;
4
5use headers::HeaderValue;
6use http::header::{ALT_SVC, CONTENT_TYPE};
7use http::uri::Scheme;
8use hyper::service::Service as HyperService;
9use hyper::{Method, Request as HyperRequest, Response as HyperResponse};
10
11use crate::catcher::{Catcher, write_error_default};
12use crate::conn::SocketAddr;
13use crate::fuse::ArcFusewire;
14use crate::handler::{Handler, WhenHoop};
15use crate::http::body::{ReqBody, ResBody};
16use crate::http::{Mime, Request, Response, StatusCode};
17use crate::routing::{FlowCtrl, PathState, Router};
18use crate::{Depot, async_trait};
19
/// An HTTP service: a [`Router`] together with service-wide hoops, an optional
/// [`Catcher`] and an optional list of allowed response media types.
#[non_exhaustive]
pub struct Service {
    /// Router used to detect the route for every incoming request.
    pub router: Arc<Router>,
    /// Catcher invoked for error responses that have no (or an error) body;
    /// when `None`, a default error body is written instead.
    pub catcher: Option<Arc<Catcher>>,
    /// Hoops run for every request, ahead of any route-specific hoops.
    pub hoops: Vec<Arc<dyn Handler>>,
    /// Media types permitted in the response `Content-Type`; an empty list
    /// disables the check.
    pub allowed_media_types: Arc<Vec<Mime>>,
}
32
33impl Debug for Service {
34 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35 f.debug_struct("Service")
36 .field("router", &self.router)
37 .field("catcher", &self.catcher)
38 .field("hoops", &self.hoops.len())
39 .field("allowed_media_types", &self.allowed_media_types.len())
40 .finish()
41 }
42}
43
44impl Service {
45 #[inline]
47 pub fn new<T>(router: T) -> Self
48 where
49 T: Into<Arc<Router>>,
50 {
51 Self {
52 router: router.into(),
53 catcher: None,
54 hoops: vec![],
55 allowed_media_types: Arc::new(vec![]),
56 }
57 }
58
59 #[inline]
61 #[must_use]
62 pub fn router(&self) -> Arc<Router> {
63 self.router.clone()
64 }
65
66 #[inline]
89 #[must_use]
90 pub fn catcher(mut self, catcher: impl Into<Arc<Catcher>>) -> Self {
91 self.catcher = Some(catcher.into());
92 self
93 }
94
95 #[inline]
97 #[must_use]
98 pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
99 self.hoops.push(Arc::new(hoop));
100 self
101 }
102
103 #[inline]
107 #[must_use]
108 pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
109 where
110 H: Handler,
111 F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
112 {
113 self.hoops.push(Arc::new(WhenHoop {
114 inner: hoop,
115 filter,
116 }));
117 self
118 }
119
120 #[inline]
133 #[must_use]
134 pub fn allowed_media_types<T>(mut self, allowed_media_types: T) -> Self
135 where
136 T: Into<Arc<Vec<Mime>>>,
137 {
138 self.allowed_media_types = allowed_media_types.into();
139 self
140 }
141
142 #[doc(hidden)]
143 #[inline]
144 #[must_use]
145 pub fn hyper_handler(
146 &self,
147 local_addr: SocketAddr,
148 remote_addr: SocketAddr,
149 http_scheme: Scheme,
150 fusewire: Option<ArcFusewire>,
151 alt_svc_h3: Option<HeaderValue>,
152 ) -> HyperHandler {
153 HyperHandler {
154 local_addr,
155 remote_addr,
156 http_scheme,
157 router: self.router.clone(),
158 catcher: self.catcher.clone(),
159 hoops: self.hoops.clone(),
160 allowed_media_types: self.allowed_media_types.clone(),
161 fusewire,
162 alt_svc_h3,
163 }
164 }
165 #[cfg(feature = "test")]
167 #[inline]
168 pub async fn handle(&self, request: impl Into<Request> + Send) -> Response {
169 let request = request.into();
170 self.hyper_handler(
171 request.local_addr.clone(),
172 request.remote_addr.clone(),
173 request.scheme.clone(),
174 None,
175 None,
176 )
177 .handle(request)
178 .await
179 }
180}
181
182impl<T> From<T> for Service
183where
184 T: Into<Arc<Router>>,
185{
186 #[inline]
187 fn from(router: T) -> Self {
188 Self::new(router)
189 }
190}
191
192struct DefaultStatusOK;
193#[async_trait]
194impl Handler for DefaultStatusOK {
195 async fn handle(
196 &self,
197 req: &mut Request,
198 depot: &mut Depot,
199 res: &mut Response,
200 ctrl: &mut FlowCtrl,
201 ) {
202 ctrl.call_next(req, depot, res).await;
203 if res.status_code.is_none() {
204 res.status_code = Some(StatusCode::OK);
205 }
206 }
207}
208
/// Request handler produced by `Service::hyper_handler`, bound to one
/// local/remote address pair and scheme and sharing (via `Arc`) the service's
/// router, catcher, hoops and allowed media types.
#[doc(hidden)]
#[derive(Clone)]
pub struct HyperHandler {
    /// Written into `Request::local_addr` of every handled request.
    pub(crate) local_addr: SocketAddr,
    /// Written into `Request::remote_addr` of every handled request.
    pub(crate) remote_addr: SocketAddr,
    /// Scheme used when the incoming request URI does not carry one.
    pub(crate) http_scheme: Scheme,
    /// Router used to detect the matching route.
    pub(crate) router: Arc<Router>,
    /// Catcher for body-less error responses; `None` means the default error writer is used.
    pub(crate) catcher: Option<Arc<Catcher>>,
    /// Service-level hoops, run ahead of route hoops.
    pub(crate) hoops: Vec<Arc<dyn Handler>>,
    /// Allowed response media types; empty disables the check.
    pub(crate) allowed_media_types: Arc<Vec<Mime>>,
    /// Optional fusewire attached to each request body.
    pub(crate) fusewire: Option<ArcFusewire>,
    /// When set, inserted as the `Alt-Svc` response header (presumably the
    /// HTTP/3 advertisement — confirm at the construction site).
    pub(crate) alt_svc_h3: Option<HeaderValue>,
}
222impl Debug for HyperHandler {
223 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
224 f.debug_struct("HyperHandler")
225 .field("local_addr", &self.local_addr)
226 .field("remote_addr", &self.remote_addr)
227 .field("http_scheme", &self.http_scheme)
228 .field("router", &self.router)
229 .field("catcher", &self.catcher)
230 .field("allowed_media_types", &self.allowed_media_types)
231 .field("alt_svc_h3", &self.alt_svc_h3)
232 .finish()
233 }
234}
235impl HyperHandler {
236 pub fn handle(&self, mut req: Request) -> impl Future<Output = Response> + 'static {
238 let catcher = self.catcher.clone();
239 let allowed_media_types = self.allowed_media_types.clone();
240 req.local_addr = self.local_addr.clone();
241 req.remote_addr = self.remote_addr.clone();
242 #[cfg(not(feature = "cookie"))]
243 let mut res = Response::new();
244 #[cfg(feature = "cookie")]
245 let mut res = Response::with_cookies(req.cookies.clone());
246 if let Some(alt_svc_h3) = &self.alt_svc_h3 {
247 if !res.headers().contains_key(ALT_SVC) {
248 res.headers_mut().insert(ALT_SVC, alt_svc_h3.clone());
249 }
250 }
251 let mut depot = Depot::new();
252 let mut path_state = PathState::new(req.uri().path());
253 let router = self.router.clone();
254
255 let hoops = self.hoops.clone();
256 async move {
257 if let Some(dm) = router.detect(&mut req, &mut path_state).await {
258 req.params = path_state.params;
259 #[cfg(feature = "matched-path")]
260 {
261 req.matched_path = path_state.matched_parts.join("/");
262 }
263 let mut ctrl = FlowCtrl::new(
266 [
267 &hoops[..],
268 &dm.hoops[..],
269 &[Arc::new(DefaultStatusOK)],
270 &[dm.goal],
271 ]
272 .concat(),
273 );
274 ctrl.call_next(&mut req, &mut depot, &mut res).await;
275 if res.status_code.is_none() {
277 res.status_code = Some(StatusCode::OK);
278 }
279 } else if !hoops.is_empty() {
280 req.params = path_state.params;
281 if path_state.once_ended {
284 res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
285 } else {
286 res.status_code = Some(StatusCode::NOT_FOUND);
287 }
288 let mut ctrl = FlowCtrl::new(hoops);
289 ctrl.call_next(&mut req, &mut depot, &mut res).await;
290 if res.status_code.is_none() && path_state.once_ended {
292 res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
293 }
294 } else if path_state.once_ended {
295 res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
296 }
297
298 let status_code = if let Some(status_code) = res.status_code {
299 status_code
300 } else {
301 res.status_code = Some(StatusCode::NOT_FOUND);
302 StatusCode::NOT_FOUND
303 };
304 if !allowed_media_types.is_empty() {
305 if let Some(ctype) = res
306 .headers()
307 .get(CONTENT_TYPE)
308 .and_then(|c| c.to_str().ok())
309 .and_then(|c| c.parse::<Mime>().ok())
310 {
311 let mut is_allowed = false;
312 for mime in &*allowed_media_types {
313 if mime.type_() == ctype.type_() && mime.subtype() == ctype.subtype() {
314 is_allowed = true;
315 break;
316 }
317 }
318 if !is_allowed {
319 res.status_code(StatusCode::UNSUPPORTED_MEDIA_TYPE);
320 }
321 }
322 }
323 let has_error = status_code.is_client_error() || status_code.is_server_error();
324 if res.body.is_none()
325 && !has_error
326 && !status_code.is_redirection()
327 && status_code != StatusCode::NO_CONTENT
328 && status_code != StatusCode::SWITCHING_PROTOCOLS
329 && [Method::GET, Method::POST, Method::PATCH, Method::PUT].contains(req.method())
330 {
331 tracing::warn!(
333 uri = ?req.uri(),
334 method = req.method().as_str(),
335 "http response content type header not set"
336 );
337 }
338 if Method::HEAD != *req.method()
339 && (res.body.is_none() || res.body.is_error())
340 && has_error
341 {
342 if let Some(catcher) = catcher {
343 catcher.catch(&mut req, &mut depot, &mut res).await;
344 } else {
345 write_error_default(&req, &mut res, None);
346 }
347 }
348 #[cfg(debug_assertions)]
349 if Method::HEAD == *req.method() && !res.body.is_none() {
350 tracing::warn!(
351 "request with head method should not have body: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/HEAD"
352 );
353 }
354 #[cfg(feature = "quinn")]
355 {
356 use bytes::Bytes;
357 use std::sync::Mutex;
358 if let Some(session) =
359 req.extensions.remove::<Arc<
360 crate::proto::WebTransportSession<salvo_http3::quinn::Connection, Bytes>,
361 >>()
362 {
363 res.extensions.insert(session);
364 }
365 if let Some(conn) = req.extensions.remove::<Arc<
366 Mutex<salvo_http3::server::Connection<salvo_http3::quinn::Connection, Bytes>>,
367 >>() {
368 res.extensions.insert(conn);
369 }
370 if let Some(stream) = req.extensions.remove::<Arc<
371 salvo_http3::server::RequestStream<
372 salvo_http3::quinn::BidiStream<Bytes>,
373 Bytes,
374 >,
375 >>() {
376 res.extensions.insert(stream);
377 }
378 }
379 res
380 }
381 }
382}
383
384impl<B> HyperService<HyperRequest<B>> for HyperHandler
385where
386 B: Into<ReqBody>,
387{
388 type Response = HyperResponse<ResBody>;
389 type Error = hyper::Error;
390 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
391
392 #[inline]
393 fn call(
394 &self,
395 #[cfg(not(feature = "fix-http1-request-uri"))] req: HyperRequest<B>,
396 #[cfg(feature = "fix-http1-request-uri")] mut req: HyperRequest<B>,
397 ) -> Self::Future {
398 let scheme = req
399 .uri()
400 .scheme()
401 .cloned()
402 .unwrap_or_else(|| self.http_scheme.clone());
403 #[cfg(feature = "fix-http1-request-uri")]
405 if req.uri().scheme().is_none() {
406 if let Some(host) = req
407 .headers()
408 .get(http::header::HOST)
409 .and_then(|host| host.to_str().ok())
410 .and_then(|host| host.parse::<http::uri::Authority>().ok())
411 {
412 let mut uri_parts = std::mem::take(req.uri_mut()).into_parts();
413 uri_parts.scheme = Some(scheme.clone());
414 uri_parts.authority = Some(host);
415 if let Ok(uri) = http::uri::Uri::from_parts(uri_parts) {
416 *req.uri_mut() = uri;
417 }
418 }
419 }
420 let mut request = Request::from_hyper(req, scheme);
421 request.body.set_fusewire(self.fusewire.clone());
422 let response = self.handle(request);
423 Box::pin(async move { Ok(response.await.into_hyper()) })
424 }
425}
426
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    #[tokio::test]
    async fn test_service() {
        /// Shared body of the three "before" hoops: render `label`, then stop
        /// the chain when query parameter `b` equals `stop_on`.
        async fn render_then_maybe_stop(
            label: &'static str,
            stop_on: &str,
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain(label));
            let b = req.query::<String>("b").unwrap_or_default();
            if b == stop_on {
                ctrl.skip_rest();
            } else {
                ctrl.call_next(req, depot, res).await;
            }
        }
        #[handler]
        async fn before1(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            render_then_maybe_stop("before1", "1", req, depot, res, ctrl).await;
        }
        #[handler]
        async fn before2(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            render_then_maybe_stop("before2", "2", req, depot, res, ctrl).await;
        }
        #[handler]
        async fn before3(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            render_then_maybe_stop("before3", "3", req, depot, res, ctrl).await;
        }
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("hello")
        }
        let innermost = Router::with_hoop(before3).path("hello").goal(hello);
        let middle = Router::with_hoop(before2).path("level2").push(innermost);
        let router = Router::with_path("level1").hoop(before1).push(middle);
        let service = Service::new(router);

        async fn access(service: &Service, b: &str) -> String {
            TestClient::get(format!("http://127.0.0.1:5801/level1/level2/hello?b={b}"))
                .send(service)
                .await
                .take_string()
                .await
                .unwrap()
        }
        let cases = [
            ("", "before1before2before3hello"),
            ("1", "before1"),
            ("2", "before1before2"),
            ("3", "before1before2before3"),
        ];
        for (b, expected) in cases {
            assert_eq!(access(&service, b).await, expected, "b={b}");
        }
    }

    #[tokio::test]
    async fn test_service_405_or_404_error() {
        #[handler]
        async fn login() -> &'static str {
            "login"
        }
        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }
        let login_routes = Router::with_path("login")
            .post(login)
            .push(Router::with_path("user").get(login));
        let router = Router::new()
            .push(Router::with_path("hello").goal(hello))
            .push(login_routes);
        let service = Service::new(router);

        // Sends `$method $path` to `service` and yields the response status.
        macro_rules! status_of {
            ($method:ident, $path:literal) => {
                TestClient::$method(concat!("http://127.0.0.1:5801", $path))
                    .send(&service)
                    .await
                    .status_code
                    .unwrap()
            };
        }

        assert_eq!(status_of!(get, "/hello"), StatusCode::OK);
        assert_eq!(status_of!(put, "/hello"), StatusCode::OK);

        assert_eq!(status_of!(post, "/login"), StatusCode::OK);
        assert_eq!(status_of!(get, "/login"), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(status_of!(get, "/login2"), StatusCode::NOT_FOUND);

        assert_eq!(status_of!(get, "/login/user"), StatusCode::OK);
        assert_eq!(status_of!(post, "/login/user"), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(status_of!(post, "/login/user1"), StatusCode::NOT_FOUND);
    }
}