1use std::any::TypeId;
4use std::collections::HashMap;
5use std::future::Future;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use http::{HeaderMap, Method, StatusCode, Uri};
11
12use crate::error::{Error, Result};
13use crate::hooks::{
14 ErrorContext, ErrorEvent, PanicEvent, RequestEvent, RequestInfo, ResponseEvent,
15 ValidationErrorEvent,
16};
17use crate::lifespan::{ErasedLifespan, Lifespan, LifespanCell, LifespanContext, ReadyContext};
18use crate::logging::{install as install_logging, Logger, LoggerConfig};
19use crate::middleware::{resolve_duplicates, Middleware, Next, Request};
20use crate::multipart::{AppUploadConfig, UploadConfig};
21use crate::openapi::{AsyncApiProvider, OpenApiProvider};
22use crate::response::{IntoResponse, Response};
23use crate::router::matcher::Matcher;
24use crate::router::{
25 BoxFuture, Route, Router, SharedErrorHook, SharedRequestHook, SharedResponseHook,
26 SharedValidationErrorHook,
27};
28use crate::server::{run_with_shutdown, shutdown_signal, Http1Config, Http2Config};
29#[cfg(feature = "tls")]
30use crate::tls::TlsConfig;
31use crate::state::{AppStateRef, StateMap};
32use crate::ws::{
33 AppWsConfig, WebSocketConfig, WsConnectHook, WsConnectInfo, WsDisconnectHook, WsDisconnectInfo,
34 WsHooks,
35};
36
/// Boxed, type-erased lifecycle hook taking no arguments.
///
/// Element type of `App::on_startup` and `App::on_shutdown`.
type Hook = Box<dyn Fn() -> BoxFuture<'static, Result<()>> + Send + Sync>;

/// Boxed hook called with a [`ReadyContext`]; element type of `App::on_ready`.
type ReadyHook = Box<dyn Fn(ReadyContext) -> BoxFuture<'static, Result<()>> + Send + Sync>;

/// Boxed application-level hook called with a [`RequestEvent`].
type RequestHook = Box<dyn Fn(RequestEvent) -> BoxFuture<'static, ()> + Send + Sync>;

/// Boxed application-level hook called with a [`ResponseEvent`].
type ResponseHook = Box<dyn Fn(ResponseEvent) -> BoxFuture<'static, ()> + Send + Sync>;

/// Boxed application-level hook called with an [`ErrorEvent`].
type ErrorHook = Box<dyn Fn(ErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;

/// Boxed application-level hook called with a [`ValidationErrorEvent`].
type ValidationErrorHook =
    Box<dyn Fn(ValidationErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;

/// Boxed application-level hook called with a [`PanicEvent`].
type PanicHook = Box<dyn Fn(PanicEvent) -> BoxFuture<'static, ()> + Send + Sync>;

/// Boxed handler that turns an [`Error`] plus its [`ErrorContext`] into a
/// [`Response`]; stored per source `TypeId` in `exception_handlers`.
type ExceptionHandlerFn =
    Box<dyn Fn(Error, ErrorContext) -> BoxFuture<'static, Response> + Send + Sync>;

/// Name of the header that [`request_info`] reads to obtain the request id.
const REQUEST_ID_HEADER: &str = "x-request-id";
65
/// Application builder.
///
/// Collects state, routers, middleware, hooks and server settings through
/// chained builder methods, then is turned into an [`AppInner`] by
/// [`App::build`] or run directly through the `serve*` methods.
pub struct App {
    // --- state and routing -------------------------------------------------
    /// Values registered through `state(...)` (and by `build` itself).
    state: StateMap,
    /// Routers whose routes are flattened into the matcher at build time.
    routers: Vec<Router>,
    /// Optional provider of OpenAPI documentation routes.
    openapi: Option<Box<dyn OpenApiProvider>>,
    /// Optional provider of AsyncAPI documentation routes.
    asyncapi: Option<Box<dyn AsyncApiProvider>>,
    /// Middleware in registration order; duplicates are resolved in `build`.
    middleware: Vec<Arc<dyn Middleware>>,

    // --- lifecycle ---------------------------------------------------------
    /// Registered lifespans; mutually exclusive with startup/shutdown hooks
    /// (see `validate_lifecycle`).
    lifespan: Vec<Box<dyn ErasedLifespan>>,
    on_startup: Vec<Hook>,
    on_shutdown: Vec<Hook>,
    /// Hooks run after the listener is bound, before serving starts.
    on_ready: Vec<ReadyHook>,

    // --- request-level hooks ----------------------------------------------
    on_request: Vec<RequestHook>,
    on_response: Vec<ResponseHook>,
    on_error: Vec<ErrorHook>,
    on_validation_error: Vec<ValidationErrorHook>,
    on_panic: Vec<PanicHook>,
    /// Panic-handling flag; `true` by default (see `App::new`).
    catch_panics: bool,
    /// Handlers keyed by the `TypeId` of the error source they accept.
    exception_handlers: HashMap<TypeId, ExceptionHandlerFn>,

    // --- feature configuration --------------------------------------------
    ws_config: Option<WebSocketConfig>,
    on_ws_connect: Vec<WsConnectHook>,
    on_ws_disconnect: Vec<WsDisconnectHook>,
    upload_config: Option<UploadConfig>,
    logger_config: Option<LoggerConfig>,
    cache: Option<crate::cache::Cache>,
    throttler: Option<crate::throttle::Throttler>,
    /// When set, `build` registers an `SseLimiter` with this limit.
    max_sse_connections: Option<usize>,
    /// When set, `build` registers an `AppBodyLimit` with this size.
    max_request_body_size: Option<usize>,

    // --- server settings ---------------------------------------------------
    /// Passed to `bind_tcp_listener` by `serve_with_shutdown`.
    reuse_port: bool,
    idle_timeout: Option<Duration>,
    /// Defaults to `DEFAULT_HEADER_READ_TIMEOUT` (see `App::new`).
    header_read_timeout: Option<Duration>,
    http1_config: Option<Http1Config>,
    http2_config: Option<Http2Config>,
    /// Converted into a TLS acceptor by `build`.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
}
109
110impl Default for App {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl App {
117 pub fn new() -> Self {
119 Self {
120 state: StateMap::new(),
121 routers: Vec::new(),
122 openapi: None,
123 asyncapi: None,
124 middleware: Vec::new(),
125 lifespan: Vec::new(),
126 on_startup: Vec::new(),
127 on_shutdown: Vec::new(),
128 on_ready: Vec::new(),
129 on_request: Vec::new(),
130 on_response: Vec::new(),
131 on_error: Vec::new(),
132 on_validation_error: Vec::new(),
133 on_panic: Vec::new(),
134 catch_panics: true,
135 exception_handlers: HashMap::new(),
136 ws_config: None,
137 on_ws_connect: Vec::new(),
138 on_ws_disconnect: Vec::new(),
139 upload_config: None,
140 logger_config: None,
141 cache: None,
142 throttler: None,
143 max_sse_connections: None,
144 max_request_body_size: None,
145 reuse_port: false,
146 idle_timeout: None,
147 header_read_timeout: Some(crate::constants::DEFAULT_HEADER_READ_TIMEOUT),
148 http1_config: None,
149 http2_config: None,
150 #[cfg(feature = "tls")]
151 tls_config: None,
152 }
153 }
154
155 pub fn state<S: Send + Sync + 'static>(mut self, state: S) -> Self {
158 self.state.insert(state);
159 self
160 }
161
162 pub fn logger(mut self, config: LoggerConfig) -> Self {
166 self.logger_config = Some(config);
167 self
168 }
169
170 pub fn cache(mut self, cache: crate::cache::Cache) -> Self {
176 self.cache = Some(cache);
177 self
178 }
179
180 pub fn throttle(mut self, throttle: crate::throttle::Throttle) -> Self {
193 self.throttler = Some(crate::throttle::Throttler::new(throttle));
194 self
195 }
196
197 #[cfg(feature = "redis")]
204 pub fn redis(mut self, redis: crate::Redis) -> Self {
205 self.state.insert(redis);
206 self
207 }
208
209 pub fn include_router(mut self, router: Router) -> Self {
211 self.routers.push(router);
212 self
213 }
214
215 pub fn include(self, route: impl FnOnce() -> Route) -> Self {
221 self.include_router(Router::new().route(route()))
222 }
223
224 pub fn openapi<P: OpenApiProvider>(mut self, provider: P) -> Self {
226 self.openapi = Some(Box::new(provider));
227 self
228 }
229
230 pub fn asyncapi<P: AsyncApiProvider>(mut self, provider: P) -> Self {
232 self.asyncapi = Some(Box::new(provider));
233 self
234 }
235
236 pub fn middleware<M: Middleware>(mut self, middleware: M) -> Self {
241 self.middleware.push(Arc::new(middleware));
242 self
243 }
244
245 pub fn lifespan<L: Lifespan>(mut self) -> Self {
252 self.lifespan.push(Box::new(LifespanCell::<L>::new()));
253 self
254 }
255
256 pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
258 where
259 F: Fn() -> Fut + Send + Sync + 'static,
260 Fut: Future<Output = Result<()>> + Send + 'static,
261 {
262 self.on_startup.push(Box::new(move || Box::pin(hook())));
263 self
264 }
265
266 pub fn on_shutdown<F, Fut>(mut self, hook: F) -> Self
268 where
269 F: Fn() -> Fut + Send + Sync + 'static,
270 Fut: Future<Output = Result<()>> + Send + 'static,
271 {
272 self.on_shutdown.push(Box::new(move || Box::pin(hook())));
273 self
274 }
275
276 pub fn on_ready<F, Fut>(mut self, hook: F) -> Self
280 where
281 F: Fn(ReadyContext) -> Fut + Send + Sync + 'static,
282 Fut: Future<Output = Result<()>> + Send + 'static,
283 {
284 self.on_ready.push(Box::new(move |ctx| Box::pin(hook(ctx))));
285 self
286 }
287
288 pub fn on_request<F, Fut>(mut self, hook: F) -> Self
293 where
294 F: Fn(RequestEvent) -> Fut + Send + Sync + 'static,
295 Fut: Future<Output = ()> + Send + 'static,
296 {
297 self.on_request
298 .push(Box::new(move |event| Box::pin(hook(event))));
299 self
300 }
301
302 pub fn on_response<F, Fut>(mut self, hook: F) -> Self
307 where
308 F: Fn(ResponseEvent) -> Fut + Send + Sync + 'static,
309 Fut: Future<Output = ()> + Send + 'static,
310 {
311 self.on_response
312 .push(Box::new(move |event| Box::pin(hook(event))));
313 self
314 }
315
316 pub fn on_error<F, Fut>(mut self, hook: F) -> Self
321 where
322 F: Fn(ErrorEvent) -> Fut + Send + Sync + 'static,
323 Fut: Future<Output = ()> + Send + 'static,
324 {
325 self.on_error
326 .push(Box::new(move |event| Box::pin(hook(event))));
327 self
328 }
329
330 pub fn on_validation_error<F, Fut>(mut self, hook: F) -> Self
333 where
334 F: Fn(ValidationErrorEvent) -> Fut + Send + Sync + 'static,
335 Fut: Future<Output = ()> + Send + 'static,
336 {
337 self.on_validation_error
338 .push(Box::new(move |event| Box::pin(hook(event))));
339 self
340 }
341
342 pub fn on_panic<F, Fut>(mut self, hook: F) -> Self
347 where
348 F: Fn(PanicEvent) -> Fut + Send + Sync + 'static,
349 Fut: Future<Output = ()> + Send + 'static,
350 {
351 self.on_panic
352 .push(Box::new(move |event| Box::pin(hook(event))));
353 self
354 }
355
356 pub fn websocket_config(mut self, config: WebSocketConfig) -> Self {
360 self.ws_config = Some(config);
361 self
362 }
363
364 pub fn max_sse_connections(mut self, limit: usize) -> Self {
371 self.max_sse_connections = Some(limit);
372 self
373 }
374
375 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
383 self.idle_timeout = Some(timeout);
384 self
385 }
386
387 pub fn reuse_port(mut self, enabled: bool) -> Self {
393 self.reuse_port = enabled;
394 self
395 }
396
397 pub fn max_request_body_size(mut self, bytes: usize) -> Self {
405 self.max_request_body_size = Some(bytes);
406 self
407 }
408
409 pub fn header_read_timeout(mut self, timeout: Duration) -> Self {
417 self.header_read_timeout = Some(timeout);
418 self
419 }
420
421 pub fn without_header_read_timeout(mut self) -> Self {
425 self.header_read_timeout = None;
426 self
427 }
428
429 pub fn http1(mut self, config: Http1Config) -> Self {
431 self.http1_config = Some(config);
432 self
433 }
434
435 pub fn http2(mut self, config: Http2Config) -> Self {
439 self.http2_config = Some(config);
440 self
441 }
442
443 #[cfg(feature = "tls")]
449 pub fn tls(mut self, config: TlsConfig) -> Self {
450 self.tls_config = Some(config);
451 self
452 }
453
454 pub fn on_ws_connect<F, Fut>(mut self, hook: F) -> Self
456 where
457 F: Fn(WsConnectInfo) -> Fut + Send + Sync + 'static,
458 Fut: Future<Output = ()> + Send + 'static,
459 {
460 self.on_ws_connect
461 .push(Box::new(move |info| Box::pin(hook(info))));
462 self
463 }
464
465 pub fn on_ws_disconnect<F, Fut>(mut self, hook: F) -> Self
469 where
470 F: Fn(WsDisconnectInfo) -> Fut + Send + Sync + 'static,
471 Fut: Future<Output = ()> + Send + 'static,
472 {
473 self.on_ws_disconnect
474 .push(Box::new(move |info| Box::pin(hook(info))));
475 self
476 }
477
478 pub fn upload_config(mut self, config: UploadConfig) -> Self {
482 self.upload_config = Some(config);
483 self
484 }
485
486 pub fn catch_panics(mut self) -> Self {
493 self.catch_panics = true;
494 self
495 }
496
497 pub fn propagate_panics(mut self) -> Self {
500 self.catch_panics = false;
501 self
502 }
503
504 pub fn exception_handler<E, F, Fut>(mut self, handler: F) -> Self
511 where
512 E: std::error::Error + Send + Sync + 'static,
513 F: Fn(E, ErrorContext) -> Fut + Send + Sync + 'static,
514 Fut: Future<Output = Response> + Send + 'static,
515 {
516 self.exception_handlers.insert(
517 TypeId::of::<E>(),
518 Box::new(move |mut error, ctx| match error.take_source::<E>() {
519 Some(value) => Box::pin(handler(value, ctx)),
520 None => Box::pin(async move { error.into_response() }),
523 }),
524 );
525 self
526 }
527
528 fn validate_lifecycle(&self) -> Result<()> {
530 if !self.lifespan.is_empty()
531 && (!self.on_startup.is_empty() || !self.on_shutdown.is_empty())
532 {
533 return Err(Error::internal(
534 "Cannot use `.lifespan(...)` together with `.on_startup(...)` or `.on_shutdown(...)`.\n\
535 Use either lifespan or event hooks, not both.",
536 )
537 .with_code("LIFECYCLE_CONFLICT"));
538 }
539 Ok(())
540 }
541
542 pub fn build(self) -> Result<AppInner> {
551 self.validate_lifecycle()?;
552
553 let App {
554 mut state,
555 routers,
556 openapi,
557 asyncapi,
558 middleware,
559 on_request,
560 on_response,
561 on_error,
562 on_validation_error,
563 on_panic,
564 catch_panics,
565 exception_handlers,
566 ws_config,
567 on_ws_connect,
568 on_ws_disconnect,
569 upload_config,
570 logger_config,
571 cache,
572 throttler,
573 max_sse_connections,
574 max_request_body_size,
575 idle_timeout,
576 header_read_timeout,
577 http1_config,
578 http2_config,
579 #[cfg(feature = "tls")]
580 tls_config,
581 ..
582 } = self;
583
584 #[cfg(feature = "tls")]
587 let tls_acceptor = match tls_config {
588 Some(config) => Some(config.into_acceptor()?),
589 None => None,
590 };
591 let request_logs = logger_config
593 .as_ref()
594 .map(|config| config.request_logs)
595 .unwrap_or(true);
596
597 let (ws_shutdown, ws_shutdown_rx) = tokio::sync::watch::channel(false);
601 state.insert(crate::ws::WsShutdown(ws_shutdown_rx));
602
603 if let Some(limit) = max_sse_connections {
605 state.insert(crate::sse::SseLimiter::new(limit));
606 }
607
608 if let Some(limit) = max_request_body_size {
610 state.insert(crate::extract::body::AppBodyLimit(limit));
611 }
612
613 if let Some(config) = ws_config {
615 if let Some(max) = config.ip_connection_limit() {
617 state.insert(crate::ws::WsIpLimiter::new(max));
618 }
619 state.insert(AppWsConfig(config));
620 }
621 if !on_ws_connect.is_empty() || !on_ws_disconnect.is_empty() {
623 state.insert(WsHooks {
624 connect: on_ws_connect,
625 disconnect: on_ws_disconnect,
626 });
627 }
628 if let Some(config) = upload_config {
630 state.insert(AppUploadConfig(config));
631 }
632 if let Some(cache) = cache {
634 state.insert(cache);
635 }
636 if let Some(throttler) = throttler {
638 state.insert(throttler);
639 }
640
641 let mut routes = Vec::new();
642 for router in routers {
643 routes.extend(router.into_routes());
644 }
645
646 if let Some(provider) = openapi {
647 let documentation = provider.documentation_routes(&routes);
648 routes.extend(documentation);
649 }
650
651 if let Some(provider) = asyncapi {
652 let documentation = provider.documentation_routes(&routes);
653 routes.extend(documentation);
654 }
655
656 let matcher = Matcher::build(routes)?;
657 let middleware = resolve_duplicates(middleware)?;
658
659 Ok(AppInner {
660 state: Arc::new(state),
661 matcher,
662 middleware: middleware.into(),
663 on_request: on_request.into(),
664 on_response: on_response.into(),
665 on_error: on_error.into(),
666 on_validation_error: on_validation_error.into(),
667 on_panic: on_panic.into(),
668 catch_panics,
669 request_logs,
670 exception_handlers: Arc::new(exception_handlers),
671 ws_shutdown,
672 idle_timeout,
673 header_read_timeout,
674 http1_config,
675 http2_config,
676 #[cfg(feature = "tls")]
677 tls_acceptor,
678 })
679 }
680
681 pub async fn serve(self, addr: impl AsRef<str>) -> Result<()> {
693 self.serve_with_shutdown(addr, shutdown_signal()).await
694 }
695
696 pub async fn serve_with_shutdown<S>(self, addr: impl AsRef<str>, shutdown: S) -> Result<()>
701 where
702 S: std::future::Future<Output = ()>,
703 {
704 let reuse_port = self.reuse_port;
705 let addr = addr.as_ref().to_owned();
706 self.serve_listener(shutdown, move || async move {
707 let listener = crate::server::bind_tcp_listener(&addr, reuse_port)
708 .await
709 .map_err(|error| Error::internal(format!("failed to bind {addr}: {error}")))?;
710 let local = listener.local_addr().map_err(|error| {
711 Error::internal(format!("failed to read local address: {error}"))
712 })?;
713 Ok((listener, local, None))
714 })
715 .await
716 }
717
718 #[cfg(unix)]
723 pub async fn serve_unix(self, path: impl AsRef<std::path::Path>) -> Result<()> {
724 self.serve_unix_with_shutdown(path, shutdown_signal()).await
725 }
726
727 #[cfg(unix)]
730 pub async fn serve_unix_with_shutdown<S>(
731 self,
732 path: impl AsRef<std::path::Path>,
733 shutdown: S,
734 ) -> Result<()>
735 where
736 S: std::future::Future<Output = ()>,
737 {
738 let path = path.as_ref().to_owned();
739 self.serve_listener(shutdown, move || async move {
740 let _ = std::fs::remove_file(&path);
742 let listener = tokio::net::UnixListener::bind(&path).map_err(|error| {
743 Error::internal(format!("failed to bind {}: {error}", path.display()))
744 })?;
745 let placeholder: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder address");
747 Ok((listener, placeholder, Some(path.display().to_string())))
748 })
749 .await
750 }
751
752 async fn serve_listener<S, L, F, Fut>(mut self, shutdown: S, bind: F) -> Result<()>
756 where
757 S: std::future::Future<Output = ()>,
758 L: crate::server::IncomingListener,
759 F: FnOnce() -> Fut,
760 Fut: std::future::Future<Output = Result<(L, SocketAddr, Option<String>)>>,
761 {
762 let logger_config = self.logger_config.clone().unwrap_or_default();
766 let _logger_handle = install_logging(&logger_config);
767 let log = Logger::framework("Tork");
768
769 log.info("Starting Tork application").emit();
770 self.validate_lifecycle().inspect_err(log_boot_error)?;
771 self.run_startup().await.inspect_err(log_boot_error)?;
772
773 let mut lifespan = std::mem::take(&mut self.lifespan);
774 let on_shutdown = std::mem::take(&mut self.on_shutdown);
775 let on_ready = std::mem::take(&mut self.on_ready);
776
777 let app = Arc::new(self.build().inspect_err(log_boot_error)?);
778
779 let explorer = Logger::framework("RouterExplorer");
781 for route in app.matcher().routes() {
782 explorer
783 .info(format!("Mapped {{{} {}}}", route.method(), route.path()))
784 .emit();
785 }
786
787 let (listener, ready_addr, unix_display) = bind().await.inspect_err(log_boot_error)?;
789
790 for hook in &on_ready {
791 hook(ReadyContext::new(ready_addr))
792 .await
793 .inspect_err(log_boot_error)?;
794 }
795
796 let running_on = match unix_display {
797 Some(path) => format!("unix:{path}"),
798 None => {
799 #[cfg(feature = "tls")]
800 let scheme = if app.tls_acceptor().is_some() { "https" } else { "http" };
801 #[cfg(not(feature = "tls"))]
802 let scheme = "http";
803 format!("{scheme}://{ready_addr}")
804 }
805 };
806 Logger::framework("Tork")
807 .info(format!("Application is running on {running_on}"))
808 .emit();
809
810 run_with_shutdown(app, listener, shutdown).await;
811
812 log.info("Shutting down").emit();
813 run_shutdown(&mut lifespan, &on_shutdown).await;
814 Ok(())
815 }
816
817 async fn run_startup(&mut self) -> Result<()> {
821 if !self.lifespan.is_empty() {
822 for index in 0..self.lifespan.len() {
823 let ctx = LifespanContext::new();
824 if let Err(error) = self.lifespan[index].startup(ctx, &mut self.state).await {
825 for started in (0..index).rev() {
826 let _ = self.lifespan[started].shutdown().await;
827 }
828 return Err(error);
829 }
830 }
831 } else {
832 for hook in &self.on_startup {
833 hook().await?;
834 }
835 }
836 Ok(())
837 }
838
839 pub async fn build_test(self) -> Result<TestApp> {
846 self.build_test_with(|_| {}).await
847 }
848
849 pub(crate) async fn build_test_with(
853 mut self,
854 apply: impl FnOnce(&mut StateMap),
855 ) -> Result<TestApp> {
856 self.validate_lifecycle()?;
857 self.run_startup().await?;
858 apply(&mut self.state);
859 let lifespan = std::mem::take(&mut self.lifespan);
860 let on_shutdown = std::mem::take(&mut self.on_shutdown);
861 let inner = Arc::new(self.build()?);
862 Ok(TestApp {
863 inner,
864 lifespan,
865 on_shutdown,
866 })
867 }
868}
869
870fn log_boot_error(error: &Error) {
872 Logger::framework("Tork")
873 .error(error.message().to_owned())
874 .field("code", error.code())
875 .emit();
876}
877
878async fn run_shutdown(lifespan: &mut [Box<dyn ErasedLifespan>], on_shutdown: &[Hook]) {
881 let log = Logger::framework("Lifecycle");
882 if !lifespan.is_empty() {
883 for cell in lifespan.iter_mut().rev() {
884 if let Err(error) = cell.shutdown().await {
885 log.error(format!("shutdown failed: {}", error.message()))
886 .emit();
887 }
888 }
889 } else {
890 for hook in on_shutdown {
891 if let Err(error) = hook().await {
892 log.error(format!("shutdown hook failed: {}", error.message()))
893 .emit();
894 }
895 }
896 }
897}
898
/// Application built for tests by `App::build_test` / `App::build_test_with`.
///
/// Startup has already run when a value of this type exists; call
/// [`TestApp::shutdown`] to run the matching shutdown phase.
pub struct TestApp {
    /// The built application.
    pub(crate) inner: Arc<AppInner>,
    /// Lifespans moved out of the builder after startup.
    pub(crate) lifespan: Vec<Box<dyn ErasedLifespan>>,
    /// Shutdown hooks moved out of the builder.
    pub(crate) on_shutdown: Vec<Hook>,
}
910
911impl TestApp {
912 pub async fn shutdown(mut self) -> Result<()> {
914 run_shutdown(&mut self.lifespan, &self.on_shutdown).await;
915 Ok(())
916 }
917}
918
/// Built, immutable application produced by [`App::build`].
///
/// Holds everything needed to handle a request: shared state, the route
/// matcher, middleware, application-level hooks and server settings.
pub struct AppInner {
    /// Shared application state.
    state: AppStateRef,
    /// Route matcher built from all included routers plus documentation routes.
    matcher: Matcher,
    /// Middleware after duplicate resolution.
    middleware: Arc<[Arc<dyn Middleware>]>,
    on_request: Arc<[RequestHook]>,
    on_response: Arc<[ResponseHook]>,
    on_error: Arc<[ErrorHook]>,
    on_validation_error: Arc<[ValidationErrorHook]>,
    on_panic: Arc<[PanicHook]>,
    catch_panics: bool,
    /// Taken from the logger configuration; `true` when none was given.
    request_logs: bool,
    /// Handlers keyed by the `TypeId` of the error source they accept.
    exception_handlers: Arc<HashMap<TypeId, ExceptionHandlerFn>>,
    /// Sender half of the watch channel whose receiver is stored in state as
    /// `WsShutdown`; see [`AppInner::begin_ws_shutdown`].
    ws_shutdown: tokio::sync::watch::Sender<bool>,
    idle_timeout: Option<Duration>,
    header_read_timeout: Option<Duration>,
    http1_config: Option<Http1Config>,
    http2_config: Option<Http2Config>,
    /// Acceptor converted from the builder's TLS configuration, if any.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
}
951
952impl AppInner {
953 pub(crate) fn header_read_timeout(&self) -> Option<Duration> {
955 self.header_read_timeout
956 }
957
958 pub(crate) fn idle_timeout(&self) -> Option<Duration> {
960 self.idle_timeout
961 }
962
963 pub(crate) fn http1_config(&self) -> Option<&Http1Config> {
965 self.http1_config.as_ref()
966 }
967
968 pub(crate) fn http2_config(&self) -> Option<&Http2Config> {
970 self.http2_config.as_ref()
971 }
972
973 #[cfg(feature = "tls")]
975 pub(crate) fn tls_acceptor(&self) -> Option<&tokio_rustls::TlsAcceptor> {
976 self.tls_acceptor.as_ref()
977 }
978
979 pub fn state(&self) -> &AppStateRef {
981 &self.state
982 }
983
984 pub fn matcher(&self) -> &Matcher {
986 &self.matcher
987 }
988
989 pub fn begin_ws_shutdown(&self) {
993 let _ = self.ws_shutdown.send(true);
994 }
995
996 pub async fn handle(self: Arc<Self>, request: Request) -> Response {
1004 let info = self
1006 .needs_request_info()
1007 .then(|| request_info(request.method(), request.uri(), request.headers(), None));
1008 let start = (!self.on_response.is_empty()).then(Instant::now);
1009
1010 if let Some(info) = &info {
1011 for hook in self.on_request.iter() {
1012 hook(RequestEvent::new(info.clone())).await;
1013 }
1014 }
1015
1016 let next = Next::new(self.clone(), self.middleware.clone());
1017 let response = match next.run(request).await {
1018 Ok(response) => response,
1019 Err(error) => match &info {
1022 Some(info) => self.render_error(error, info, None).await,
1023 None => error.into_response(),
1024 },
1025 };
1026
1027 if let Some(info) = &info {
1028 let status = response.status();
1029 let elapsed = start.map(|start| start.elapsed()).unwrap_or_default();
1030 for hook in self.on_response.iter() {
1031 hook(ResponseEvent::new(info.clone(), status, elapsed)).await;
1032 }
1033 }
1034
1035 response
1036 }
1037
1038 pub(crate) fn needs_request_info(&self) -> bool {
1043 !self.on_request.is_empty()
1044 || !self.on_response.is_empty()
1045 || !self.on_error.is_empty()
1046 || !self.on_validation_error.is_empty()
1047 || !self.on_panic.is_empty()
1048 || !self.exception_handlers.is_empty()
1049 }
1050
1051 pub(crate) fn catch_panics(&self) -> bool {
1053 self.catch_panics
1054 }
1055
1056 pub(crate) fn request_logs(&self) -> bool {
1058 self.request_logs
1059 }
1060
1061 pub(crate) async fn fire_panic(&self, info: &RequestInfo, message: &str) {
1063 for hook in self.on_panic.iter() {
1064 hook(PanicEvent::new(info.clone(), message.to_owned())).await;
1065 }
1066 }
1067
1068 pub(crate) async fn render_error(
1077 &self,
1078 error: Error,
1079 info: &RequestInfo,
1080 route: Option<&Route>,
1081 ) -> Response {
1082 if error.is_validation() {
1083 for hook in self.on_validation_error.iter() {
1084 hook(ValidationErrorEvent::new(
1085 info.clone(),
1086 error.details().to_vec(),
1087 ))
1088 .await;
1089 }
1090 if let Some(route) = route {
1091 fire_validation_hooks(route.validation_hooks(), info, &error).await;
1092 }
1093 } else {
1094 for hook in self.on_error.iter() {
1095 hook(ErrorEvent::new(
1096 info.clone(),
1097 error.kind().status(),
1098 error.static_code(),
1099 error.message().to_owned(),
1100 ))
1101 .await;
1102 }
1103 if let Some(route) = route {
1104 fire_error_hooks(route.error_hooks(), info, &error).await;
1105 }
1106 }
1107
1108 if let Some(type_id) = error.source_type() {
1109 if let Some(handler) = self.exception_handlers.get(&type_id) {
1110 return handler(error, ErrorContext::new(info.clone())).await;
1111 }
1112 }
1113
1114 error.into_response()
1115 }
1116}
1117
1118pub(crate) async fn fire_request_hooks(hooks: &[SharedRequestHook], info: &RequestInfo) {
1120 for hook in hooks {
1121 hook(RequestEvent::new(info.clone())).await;
1122 }
1123}
1124
1125pub(crate) async fn fire_response_hooks(
1127 hooks: &[SharedResponseHook],
1128 info: &RequestInfo,
1129 status: StatusCode,
1130 elapsed: Duration,
1131) {
1132 for hook in hooks.iter().rev() {
1133 hook(ResponseEvent::new(info.clone(), status, elapsed)).await;
1134 }
1135}
1136
1137async fn fire_error_hooks(hooks: &[SharedErrorHook], info: &RequestInfo, error: &Error) {
1139 for hook in hooks {
1140 hook(ErrorEvent::new(
1141 info.clone(),
1142 error.kind().status(),
1143 error.static_code(),
1144 error.message().to_owned(),
1145 ))
1146 .await;
1147 }
1148}
1149
1150async fn fire_validation_hooks(
1152 hooks: &[SharedValidationErrorHook],
1153 info: &RequestInfo,
1154 error: &Error,
1155) {
1156 for hook in hooks {
1157 hook(ValidationErrorEvent::new(
1158 info.clone(),
1159 error.details().to_vec(),
1160 ))
1161 .await;
1162 }
1163}
1164
1165pub(crate) fn request_info(
1167 method: &Method,
1168 uri: &Uri,
1169 headers: &HeaderMap,
1170 route: Option<String>,
1171) -> RequestInfo {
1172 let request_id = headers
1173 .get(REQUEST_ID_HEADER)
1174 .and_then(|value| value.to_str().ok())
1175 .map(Arc::<str>::from);
1176 RequestInfo::new(
1177 method.clone(),
1178 Arc::from(uri.path()),
1179 route.map(Arc::<str>::from),
1180 request_id,
1181 )
1182}
1183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Resources;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Set by `Boot::startup`.
    static STARTED: AtomicBool = AtomicBool::new(false);
    /// Set by `Boot::shutdown`.
    static STOPPED: AtomicBool = AtomicBool::new(false);

    /// Minimal lifespan that only records that it ran.
    #[derive(Clone)]
    struct Boot;

    impl Resources for Boot {
        fn register(&self, _registry: &mut StateMap) {}
    }

    impl Lifespan for Boot {
        async fn startup(_ctx: LifespanContext) -> Result<Self> {
            STARTED.store(true, Ordering::SeqCst);
            Ok(Boot)
        }

        async fn shutdown(self) -> Result<()> {
            STOPPED.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// An already-completed shutdown future makes `serve` run the whole
    /// startup/shutdown cycle and return.
    #[tokio::test]
    async fn serve_runs_startup_then_shutdown() {
        let served = App::new()
            .lifespan::<Boot>()
            .serve_with_shutdown("127.0.0.1:0", async {})
            .await;
        served.unwrap();

        assert!(STARTED.load(Ordering::SeqCst), "startup should have run");
        assert!(STOPPED.load(Ordering::SeqCst), "shutdown should have run");
    }

    #[test]
    fn lifespan_with_event_hooks_is_a_conflict() {
        let built = App::new()
            .lifespan::<Boot>()
            .on_startup(|| async { Ok(()) })
            .build();
        let error = built
            .err()
            .expect("lifespan plus on_startup should conflict");

        assert_eq!(error.code(), "LIFECYCLE_CONFLICT");
        let message = error.message();
        assert!(
            message.contains("Use either lifespan or event hooks"),
            "message: {}",
            message
        );
    }

    #[test]
    fn app_default_equals_new() {
        let built = App::default().build();
        assert!(built.is_ok());
    }

    #[test]
    fn websocket_config_is_accepted_by_builder() {
        let config = crate::ws::WebSocketConfig::new().idle_timeout_secs(5);
        let built = App::new().websocket_config(config).build();
        assert!(built.is_ok());
    }

    #[test]
    fn upload_config_is_accepted_by_builder() {
        let config = crate::multipart::UploadConfig::new().max_file_size(1024);
        let built = App::new().upload_config(config).build();
        assert!(built.is_ok());
    }

    #[test]
    fn on_ws_connect_and_disconnect_hooks_are_accepted_by_builder() {
        let built = App::new()
            .on_ws_connect(|_info| async {})
            .on_ws_disconnect(|_info| async {})
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn logger_config_is_accepted_by_builder() {
        let config = crate::logging::LoggerConfig::new();
        let built = App::new().logger(config).build();
        assert!(built.is_ok());
    }

    #[test]
    fn app_supports_multiple_distinct_state_types() {
        struct Greeting(&'static str);
        struct Counter(u32);

        let built = App::new()
            .state(Greeting("hello"))
            .state(Counter(42))
            .build()
            .expect("app builds");

        let state = built.state();
        let greeting = state.get::<Greeting>().expect("greeting registered");
        let counter = state.get::<Counter>().expect("counter registered");
        assert_eq!(greeting.0, "hello");
        assert_eq!(counter.0, 42);
    }
}
1301
1302#[cfg(test)]
1303mod hook_tests {
1304 use super::*;
1305 use crate::body::box_body;
1306 use crate::extract::RequestContext;
1307 use crate::response::empty;
1308 use crate::router::{HandlerFn, Route};
1309 use crate::ErrorDetail;
1310 use bytes::Bytes;
1311 use http::StatusCode;
1312 use http_body_util::Full;
1313 use std::sync::Mutex;
1314
1315 type Log = Arc<Mutex<Vec<String>>>;
1316
1317 fn log() -> Log {
1318 Arc::new(Mutex::new(Vec::new()))
1319 }
1320
1321 fn entries(log: &Log) -> Vec<String> {
1322 log.lock().unwrap().clone()
1323 }
1324
1325 fn request(method: Method, uri: &str) -> Request {
1326 http::Request::builder()
1327 .method(method)
1328 .uri(uri)
1329 .body(box_body(Full::new(Bytes::new())))
1330 .unwrap()
1331 }
1332
1333 fn failing_route(make: fn() -> Error) -> Router {
1335 let handler: HandlerFn = Arc::new(
1336 move |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
1337 Box::pin(async move { Err(make()) })
1338 },
1339 );
1340 Router::new().route(Route::new(Method::GET, "/", handler))
1341 }
1342
1343 fn ok_handler() -> HandlerFn {
1344 Arc::new(
1345 |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
1346 Box::pin(async { Ok(empty(StatusCode::OK)) })
1347 },
1348 )
1349 }
1350
1351 fn ok_route() -> Router {
1352 Router::new().route(Route::new(Method::GET, "/", ok_handler()))
1353 }
1354
1355 fn panicking_route() -> Router {
1356 let handler: HandlerFn = Arc::new(
1357 |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
1358 Box::pin(async { panic!("handler boom") })
1359 },
1360 );
1361 Router::new().route(Route::new(Method::GET, "/", handler))
1362 }
1363
1364 #[tokio::test]
1365 async fn on_error_fires_for_a_missing_route() {
1366 let seen = log();
1367 let recorder = seen.clone();
1368 let app = App::new()
1369 .on_error(move |event| {
1370 let recorder = recorder.clone();
1371 let code = event.code().to_owned();
1372 async move { recorder.lock().unwrap().push(code) }
1373 })
1374 .build()
1375 .unwrap();
1376
1377 let response = app.dispatch(request(Method::GET, "/missing")).await;
1378 assert_eq!(response.status(), StatusCode::NOT_FOUND);
1379 assert_eq!(entries(&seen), vec!["NOT_FOUND".to_owned()]);
1380 }
1381
1382 #[tokio::test]
1383 async fn validation_error_fires_only_the_validation_hook() {
1384 fn validation_error() -> Error {
1385 Error::unprocessable("invalid")
1386 .with_code("VALIDATION_ERROR")
1387 .with_details(vec![ErrorDetail::new("name", "TOO_SHORT", "too short")])
1388 }
1389
1390 let errors = log();
1391 let validations = log();
1392 let error_rec = errors.clone();
1393 let validation_rec = validations.clone();
1394
1395 let app = App::new()
1396 .include_router(failing_route(validation_error))
1397 .on_error(move |_event| {
1398 let rec = error_rec.clone();
1399 async move { rec.lock().unwrap().push("error".to_owned()) }
1400 })
1401 .on_validation_error(move |event| {
1402 let rec = validation_rec.clone();
1403 let fields = event.details().len();
1404 async move { rec.lock().unwrap().push(format!("validation:{fields}")) }
1405 })
1406 .build()
1407 .unwrap();
1408
1409 let response = app.dispatch(request(Method::GET, "/")).await;
1410 assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
1411 assert_eq!(entries(&validations), vec!["validation:1".to_owned()]);
1412 assert!(
1413 entries(&errors).is_empty(),
1414 "on_error must not fire for validation"
1415 );
1416 }
1417
1418 #[derive(Debug)]
1419 struct SampleError;
1420 impl std::fmt::Display for SampleError {
1421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1422 f.write_str("sample failure")
1423 }
1424 }
1425 impl std::error::Error for SampleError {}
1426
1427 #[tokio::test]
1428 async fn exception_handler_replaces_the_response() {
1429 fn sample_error() -> Error {
1430 Error::internal("wrapped").with_source(SampleError)
1431 }
1432
1433 let app = App::new()
1434 .include_router(failing_route(sample_error))
1435 .exception_handler::<SampleError, _, _>(|error, _ctx| async move {
1436 assert_eq!(error.to_string(), "sample failure");
1438 empty(StatusCode::SERVICE_UNAVAILABLE)
1439 })
1440 .build()
1441 .unwrap();
1442
1443 let response = app.dispatch(request(Method::GET, "/")).await;
1444 assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
1445 }
1446
1447 #[tokio::test]
1448 async fn request_hooks_run_in_registration_order() {
1449 let seen = log();
1450 let first = seen.clone();
1451 let second = seen.clone();
1452
1453 let app = Arc::new(
1454 App::new()
1455 .include_router(ok_route())
1456 .on_request(move |_event| {
1457 let rec = first.clone();
1458 async move { rec.lock().unwrap().push("first".to_owned()) }
1459 })
1460 .on_request(move |_event| {
1461 let rec = second.clone();
1462 async move { rec.lock().unwrap().push("second".to_owned()) }
1463 })
1464 .build()
1465 .unwrap(),
1466 );
1467
1468 let response = app.handle(request(Method::GET, "/")).await;
1469 assert_eq!(response.status(), StatusCode::OK);
1470 assert_eq!(
1471 entries(&seen),
1472 vec!["first".to_owned(), "second".to_owned()]
1473 );
1474 }
1475
1476 #[tokio::test]
1477 async fn catch_panics_converts_a_panic_into_a_500() {
1478 let seen = log();
1479 let recorder = seen.clone();
1480 let app = App::new()
1481 .include_router(panicking_route())
1482 .catch_panics()
1483 .on_panic(move |event| {
1484 let recorder = recorder.clone();
1485 let message = event.message().to_owned();
1486 async move { recorder.lock().unwrap().push(message) }
1487 })
1488 .build()
1489 .unwrap();
1490
1491 let response = app.dispatch(request(Method::GET, "/")).await;
1492 assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
1493 assert_eq!(entries(&seen), vec!["handler boom".to_owned()]);
1494 }
1495
1496 #[tokio::test]
1497 #[should_panic(expected = "handler boom")]
1498 async fn without_catch_panics_a_panic_propagates() {
1499 let app = App::new()
1500 .propagate_panics()
1501 .include_router(panicking_route())
1502 .build()
1503 .unwrap();
1504 let _ = app.dispatch(request(Method::GET, "/")).await;
1505 }
1506
1507 #[tokio::test]
1508 async fn scoped_on_request_fires_only_for_its_router() {
1509 let seen = log();
1510 let recorder = seen.clone();
1511 let scoped = Router::new()
1512 .route(Route::new(Method::GET, "/a", ok_handler()))
1513 .on_request(move |_event| {
1514 let recorder = recorder.clone();
1515 async move { recorder.lock().unwrap().push("a".to_owned()) }
1516 });
1517 let plain = Router::new().route(Route::new(Method::GET, "/b", ok_handler()));
1518 let app = App::new()
1519 .include_router(scoped)
1520 .include_router(plain)
1521 .build()
1522 .unwrap();
1523
1524 let _ = app.dispatch(request(Method::GET, "/a")).await;
1525 let _ = app.dispatch(request(Method::GET, "/b")).await;
1526 assert_eq!(entries(&seen), vec!["a".to_owned()]);
1527 }
1528
1529 #[tokio::test]
1530 async fn scoped_on_error_runs_after_the_global_hook() {
1531 let seen = log();
1532 let global = seen.clone();
1533 let scoped = seen.clone();
1534
1535 let router = failing_route(|| Error::not_found("missing")).on_error(move |_event| {
1536 let scoped = scoped.clone();
1537 async move { scoped.lock().unwrap().push("scoped".to_owned()) }
1538 });
1539 let app = App::new()
1540 .on_error(move |_event| {
1541 let global = global.clone();
1542 async move { global.lock().unwrap().push("global".to_owned()) }
1543 })
1544 .include_router(router)
1545 .build()
1546 .unwrap();
1547
1548 let _ = app.dispatch(request(Method::GET, "/")).await;
1549 assert_eq!(
1550 entries(&seen),
1551 vec!["global".to_owned(), "scoped".to_owned()]
1552 );
1553 }
1554
1555 #[tokio::test]
1556 async fn scoped_on_response_hooks_fire_in_reverse() {
1557 let seen = log();
1558 let first = seen.clone();
1559 let second = seen.clone();
1560 let router = Router::new()
1561 .route(Route::new(Method::GET, "/", ok_handler()))
1562 .on_response(move |_event| {
1563 let first = first.clone();
1564 async move { first.lock().unwrap().push("first".to_owned()) }
1565 })
1566 .on_response(move |_event| {
1567 let second = second.clone();
1568 async move { second.lock().unwrap().push("second".to_owned()) }
1569 });
1570 let app = App::new().include_router(router).build().unwrap();
1571
1572 let _ = app.dispatch(request(Method::GET, "/")).await;
1573 assert_eq!(
1574 entries(&seen),
1575 vec!["second".to_owned(), "first".to_owned()]
1576 );
1577 }
1578}