Skip to main content

tork_core/
app.rs

1//! The application builder and its finalized, request-handling core.
2
3use std::any::TypeId;
4use std::collections::HashMap;
5use std::future::Future;
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use http::{HeaderMap, Method, StatusCode, Uri};
11
12use crate::error::{Error, Result};
13use crate::hooks::{
14    ErrorContext, ErrorEvent, PanicEvent, RequestEvent, RequestInfo, ResponseEvent,
15    ValidationErrorEvent,
16};
17use crate::lifespan::{ErasedLifespan, Lifespan, LifespanCell, LifespanContext, ReadyContext};
18use crate::logging::{install as install_logging, Logger, LoggerConfig};
19use crate::middleware::{resolve_duplicates, Middleware, Next, Request};
20use crate::multipart::{AppUploadConfig, UploadConfig};
21use crate::openapi::{AsyncApiProvider, OpenApiProvider};
22use crate::response::{IntoResponse, Response};
23use crate::router::matcher::Matcher;
24use crate::router::{
25    BoxFuture, Route, Router, SharedErrorHook, SharedRequestHook, SharedResponseHook,
26    SharedValidationErrorHook,
27};
28use crate::server::{run_with_shutdown, shutdown_signal, Http1Config, Http2Config};
29#[cfg(feature = "tls")]
30use crate::tls::TlsConfig;
31use crate::state::{AppStateRef, StateMap};
32use crate::ws::{
33    AppWsConfig, WebSocketConfig, WsConnectHook, WsConnectInfo, WsDisconnectHook, WsDisconnectInfo,
34    WsHooks,
35};
36
37/// A startup or shutdown event hook.
38type Hook = Box<dyn Fn() -> BoxFuture<'static, Result<()>> + Send + Sync>;
39
40/// A post-bind readiness hook.
41type ReadyHook = Box<dyn Fn(ReadyContext) -> BoxFuture<'static, Result<()>> + Send + Sync>;
42
43/// An observe-only hook fired when a request arrives.
44type RequestHook = Box<dyn Fn(RequestEvent) -> BoxFuture<'static, ()> + Send + Sync>;
45
46/// An observe-only hook fired when a response is ready.
47type ResponseHook = Box<dyn Fn(ResponseEvent) -> BoxFuture<'static, ()> + Send + Sync>;
48
49/// An observe-only hook fired for a non-validation error.
50type ErrorHook = Box<dyn Fn(ErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
51
52/// An observe-only hook fired for a request-body validation failure.
53type ValidationErrorHook =
54    Box<dyn Fn(ValidationErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
55
56/// An observe-only hook fired when a handler panic is caught.
57type PanicHook = Box<dyn Fn(PanicEvent) -> BoxFuture<'static, ()> + Send + Sync>;
58
59/// Maps a recovered typed error into a response.
60type ExceptionHandlerFn =
61    Box<dyn Fn(Error, ErrorContext) -> BoxFuture<'static, Response> + Send + Sync>;
62
63/// Header consulted to correlate hook events with a request identifier.
64const REQUEST_ID_HEADER: &str = "x-request-id";
65
66/// The application builder.
67///
68/// `App` collects application state, routers, and optional OpenAPI configuration,
69/// then either finalizes into an [`AppInner`] via [`App::build`] or starts
70/// serving via [`App::serve`](crate::App::serve).
71///
72/// `App` is deliberately not generic over its state type: state is stored in a
73/// type-erased [`StateMap`], which is what lets router modules be defined without
74/// any knowledge of the concrete state type.
75pub struct App {
76    state: StateMap,
77    routers: Vec<Router>,
78    openapi: Option<Box<dyn OpenApiProvider>>,
79    asyncapi: Option<Box<dyn AsyncApiProvider>>,
80    middleware: Vec<Arc<dyn Middleware>>,
81    lifespan: Vec<Box<dyn ErasedLifespan>>,
82    on_startup: Vec<Hook>,
83    on_shutdown: Vec<Hook>,
84    on_ready: Vec<ReadyHook>,
85    on_request: Vec<RequestHook>,
86    on_response: Vec<ResponseHook>,
87    on_error: Vec<ErrorHook>,
88    on_validation_error: Vec<ValidationErrorHook>,
89    on_panic: Vec<PanicHook>,
90    catch_panics: bool,
91    exception_handlers: HashMap<TypeId, ExceptionHandlerFn>,
92    ws_config: Option<WebSocketConfig>,
93    on_ws_connect: Vec<WsConnectHook>,
94    on_ws_disconnect: Vec<WsDisconnectHook>,
95    upload_config: Option<UploadConfig>,
96    logger_config: Option<LoggerConfig>,
97    cache: Option<crate::cache::Cache>,
98    throttler: Option<crate::throttle::Throttler>,
99    max_sse_connections: Option<usize>,
100    max_request_body_size: Option<usize>,
101    reuse_port: bool,
102    idle_timeout: Option<Duration>,
103    header_read_timeout: Option<Duration>,
104    http1_config: Option<Http1Config>,
105    http2_config: Option<Http2Config>,
106    #[cfg(feature = "tls")]
107    tls_config: Option<TlsConfig>,
108}
109
110impl Default for App {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl App {
117    /// Creates an empty application.
118    pub fn new() -> Self {
119        Self {
120            state: StateMap::new(),
121            routers: Vec::new(),
122            openapi: None,
123            asyncapi: None,
124            middleware: Vec::new(),
125            lifespan: Vec::new(),
126            on_startup: Vec::new(),
127            on_shutdown: Vec::new(),
128            on_ready: Vec::new(),
129            on_request: Vec::new(),
130            on_response: Vec::new(),
131            on_error: Vec::new(),
132            on_validation_error: Vec::new(),
133            on_panic: Vec::new(),
134            catch_panics: true,
135            exception_handlers: HashMap::new(),
136            ws_config: None,
137            on_ws_connect: Vec::new(),
138            on_ws_disconnect: Vec::new(),
139            upload_config: None,
140            logger_config: None,
141            cache: None,
142            throttler: None,
143            max_sse_connections: None,
144            max_request_body_size: None,
145            reuse_port: false,
146            idle_timeout: None,
147            header_read_timeout: Some(crate::constants::DEFAULT_HEADER_READ_TIMEOUT),
148            http1_config: None,
149            http2_config: None,
150            #[cfg(feature = "tls")]
151            tls_config: None,
152        }
153    }
154
155    /// Registers an application state value, retrievable via the
156    /// [`State`](crate::State) extractor.
157    pub fn state<S: Send + Sync + 'static>(mut self, state: S) -> Self {
158        self.state.insert(state);
159        self
160    }
161
162    /// Configures logging. Without this, logging is on by default with sensible
163    /// settings (a developer console on a terminal, JSON otherwise; level from
164    /// `RUST_LOG` or `info`).
165    pub fn logger(mut self, config: LoggerConfig) -> Self {
166        self.logger_config = Some(config);
167        self
168    }
169
170    /// Enables caching, making the [`Cache`](crate::Cache) injectable into handlers
171    /// and services.
172    ///
173    /// Pass a configured cache, for example `Cache::in_memory()` for the default
174    /// in-memory store. Without this call, injecting a `Cache` fails.
175    pub fn cache(mut self, cache: crate::cache::Cache) -> Self {
176        self.cache = Some(cache);
177        self
178    }
179
180    /// Enables rate limiting, defining the policies routes can apply with the
181    /// `throttle` attribute and (optionally) a global default.
182    ///
183    /// ```no_run
184    /// # use tork_core::{App, Throttle};
185    /// App::new().throttle(
186    ///     Throttle::new()
187    ///         .policy("default", 100, 60)
188    ///         .policy("strict", 5, 60)
189    ///         .default("default"),
190    /// );
191    /// ```
192    pub fn throttle(mut self, throttle: crate::throttle::Throttle) -> Self {
193        self.throttler = Some(crate::throttle::Throttler::new(throttle));
194        self
195    }
196
197    /// Registers a Redis connection, making [`Redis`](crate::Redis) injectable into
198    /// handlers and services for raw commands, Lua scripts, idempotency, and so on.
199    ///
200    /// Build it with `Redis::connect(url).await?`. Share one connection with the
201    /// cache by passing the same handle to [`Cache::from_redis`](crate::Cache::from_redis).
202    /// Available with the `redis` feature.
203    #[cfg(feature = "redis")]
204    pub fn redis(mut self, redis: crate::Redis) -> Self {
205        self.state.insert(redis);
206        self
207    }
208
209    /// Mounts a router's routes on the application.
210    pub fn include_router(mut self, router: Router) -> Self {
211        self.routers.push(router);
212        self
213    }
214
215    /// Mounts a single route, given the route factory generated for a handler.
216    ///
217    /// A `#[get]` / `#[post]` / `#[sse]` / `#[websocket]` handler named `handler`
218    /// generates a `handler()` route factory, so `App::new().include(handler)`
219    /// registers it directly without building a `Router`.
220    pub fn include(self, route: impl FnOnce() -> Route) -> Self {
221        self.include_router(Router::new().route(route()))
222    }
223
224    /// Configures OpenAPI document generation and the documentation UI.
225    pub fn openapi<P: OpenApiProvider>(mut self, provider: P) -> Self {
226        self.openapi = Some(Box::new(provider));
227        self
228    }
229
230    /// Configures AsyncAPI document generation for the SSE/WebSocket channels.
231    pub fn asyncapi<P: AsyncApiProvider>(mut self, provider: P) -> Self {
232        self.asyncapi = Some(Box::new(provider));
233        self
234    }
235
236    /// Registers a middleware layer.
237    ///
238    /// Layers run in registration order, outermost first. Some middlewares may
239    /// only be registered once; see [`DuplicatePolicy`](crate::DuplicatePolicy).
240    pub fn middleware<M: Middleware>(mut self, middleware: M) -> Self {
241        self.middleware.push(Arc::new(middleware));
242        self
243    }
244
245    /// Registers a lifespan: a resource container with typed startup/shutdown.
246    ///
247    /// Lifespans start in registration order and stop in reverse order. Their
248    /// resources are registered for injection. Using a lifespan together with
249    /// [`on_startup`](App::on_startup) or [`on_shutdown`](App::on_shutdown) is a
250    /// configuration error.
251    pub fn lifespan<L: Lifespan>(mut self) -> Self {
252        self.lifespan.push(Box::new(LifespanCell::<L>::new()));
253        self
254    }
255
256    /// Registers a startup hook (for apps that do not use a lifespan).
257    pub fn on_startup<F, Fut>(mut self, hook: F) -> Self
258    where
259        F: Fn() -> Fut + Send + Sync + 'static,
260        Fut: Future<Output = Result<()>> + Send + 'static,
261    {
262        self.on_startup.push(Box::new(move || Box::pin(hook())));
263        self
264    }
265
266    /// Registers a shutdown hook (for apps that do not use a lifespan).
267    pub fn on_shutdown<F, Fut>(mut self, hook: F) -> Self
268    where
269        F: Fn() -> Fut + Send + Sync + 'static,
270        Fut: Future<Output = Result<()>> + Send + 'static,
271    {
272        self.on_shutdown.push(Box::new(move || Box::pin(hook())));
273        self
274    }
275
276    /// Registers a hook that runs once the listener has bound.
277    ///
278    /// Allowed in both lifespan and event-hook modes.
279    pub fn on_ready<F, Fut>(mut self, hook: F) -> Self
280    where
281        F: Fn(ReadyContext) -> Fut + Send + Sync + 'static,
282        Fut: Future<Output = Result<()>> + Send + 'static,
283    {
284        self.on_ready.push(Box::new(move |ctx| Box::pin(hook(ctx))));
285        self
286    }
287
288    /// Registers an observe-only hook that runs when a request arrives.
289    ///
290    /// Hooks run in registration order, before the middleware chain, and cannot
291    /// alter the response. Use them for logging, metrics, or tracing.
292    pub fn on_request<F, Fut>(mut self, hook: F) -> Self
293    where
294        F: Fn(RequestEvent) -> Fut + Send + Sync + 'static,
295        Fut: Future<Output = ()> + Send + 'static,
296    {
297        self.on_request
298            .push(Box::new(move |event| Box::pin(hook(event))));
299        self
300    }
301
302    /// Registers an observe-only hook that runs once a response is ready.
303    ///
304    /// Hooks run in registration order, after the middleware chain, and observe
305    /// the final status and elapsed time.
306    pub fn on_response<F, Fut>(mut self, hook: F) -> Self
307    where
308        F: Fn(ResponseEvent) -> Fut + Send + Sync + 'static,
309        Fut: Future<Output = ()> + Send + 'static,
310    {
311        self.on_response
312            .push(Box::new(move |event| Box::pin(hook(event))));
313        self
314    }
315
316    /// Registers an observe-only hook that runs for a non-validation error.
317    ///
318    /// Validation failures (`422`) go to [`on_validation_error`](App::on_validation_error)
319    /// instead; every other error fires this hook.
320    pub fn on_error<F, Fut>(mut self, hook: F) -> Self
321    where
322        F: Fn(ErrorEvent) -> Fut + Send + Sync + 'static,
323        Fut: Future<Output = ()> + Send + 'static,
324    {
325        self.on_error
326            .push(Box::new(move |event| Box::pin(hook(event))));
327        self
328    }
329
330    /// Registers an observe-only hook that runs for a request-body validation
331    /// failure (`422`).
332    pub fn on_validation_error<F, Fut>(mut self, hook: F) -> Self
333    where
334        F: Fn(ValidationErrorEvent) -> Fut + Send + Sync + 'static,
335        Fut: Future<Output = ()> + Send + 'static,
336    {
337        self.on_validation_error
338            .push(Box::new(move |event| Box::pin(hook(event))));
339        self
340    }
341
342    /// Registers an observe-only hook that runs when a handler panic is caught.
343    ///
344    /// Has no effect unless the panic boundary is enabled with
345    /// [`catch_panics`](App::catch_panics).
346    pub fn on_panic<F, Fut>(mut self, hook: F) -> Self
347    where
348        F: Fn(PanicEvent) -> Fut + Send + Sync + 'static,
349        Fut: Future<Output = ()> + Send + 'static,
350    {
351        self.on_panic
352            .push(Box::new(move |event| Box::pin(hook(event))));
353        self
354    }
355
356    /// Sets default WebSocket limits and timeouts for every `#[websocket]` route.
357    ///
358    /// A route's own `#[websocket(...)]` limits override these defaults.
359    pub fn websocket_config(mut self, config: WebSocketConfig) -> Self {
360        self.ws_config = Some(config);
361        self
362    }
363
364    /// Caps the number of concurrent Server-Sent Events streams the app will serve.
365    ///
366    /// Each SSE connection holds a pinned stream and timers for its whole lifetime
367    /// (often hours), so an unbounded count can exhaust memory. Once the cap is
368    /// reached, further `#[sse]` requests are rejected with `503 Service
369    /// Unavailable` until a stream ends. With no cap set, SSE streams are unbounded.
370    pub fn max_sse_connections(mut self, limit: usize) -> Self {
371        self.max_sse_connections = Some(limit);
372        self
373    }
374
375    /// Closes a connection after this long with no read or write activity.
376    ///
377    /// Bounds slow or abandoned connections (and zombie keep-alive sockets). It is
378    /// off by default, because a legitimately idle long-lived connection — an open
379    /// WebSocket or a quiet Server-Sent Events stream — is normal; enable it only
380    /// when your routes do not rely on long-lived idle connections, or set it
381    /// comfortably above your SSE heartbeat interval.
382    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
383        self.idle_timeout = Some(timeout);
384        self
385    }
386
387    /// Binds the listening socket with `SO_REUSEPORT` (Unix), so several processes
388    /// or instances can listen on the same address and the kernel load-balances new
389    /// connections across them.
390    ///
391    /// Has no effect on non-Unix platforms. Off by default.
392    pub fn reuse_port(mut self, enabled: bool) -> Self {
393        self.reuse_port = enabled;
394        self
395    }
396
397    /// Sets the maximum size, in bytes, of a buffered request body.
398    ///
399    /// Applies to the JSON, `Valid<T>`, and urlencoded `Form<T>` body extractors,
400    /// which enforce the cap incrementally as the body arrives (an oversized body is
401    /// rejected with `400` before it is fully buffered). Multipart uploads have their
402    /// own [`UploadConfig`] limits. Defaults to
403    /// [`MAX_BODY_BYTES`](crate::constants::MAX_BODY_BYTES) (2 MiB).
404    pub fn max_request_body_size(mut self, bytes: usize) -> Self {
405        self.max_request_body_size = Some(bytes);
406        self
407    }
408
409    /// Sets how long a client may take to send the complete request head (request
410    /// line + headers) after its connection is accepted.
411    ///
412    /// This bounds slowloris-style attacks where a client opens a connection and
413    /// dribbles header bytes to tie up a worker. Defaults to
414    /// [`DEFAULT_HEADER_READ_TIMEOUT`](crate::constants::DEFAULT_HEADER_READ_TIMEOUT)
415    /// (30s); pass a longer duration to relax it. Applies to HTTP/1 connections.
416    pub fn header_read_timeout(mut self, timeout: Duration) -> Self {
417        self.header_read_timeout = Some(timeout);
418        self
419    }
420
421    /// Removes the request-head read deadline, letting a client take unlimited time
422    /// to send its headers. Only do this behind a trusted proxy that already bounds
423    /// slow clients.
424    pub fn without_header_read_timeout(mut self) -> Self {
425        self.header_read_timeout = None;
426        self
427    }
428
429    /// Tunes HTTP/1 behavior (keep-alive, header count) for every connection.
430    pub fn http1(mut self, config: Http1Config) -> Self {
431        self.http1_config = Some(config);
432        self
433    }
434
435    /// Tunes HTTP/2 behavior (stream limits, keep-alive, flow control) for every
436    /// connection. HTTP/2 is served automatically over a plaintext upgrade or over
437    /// TLS via ALPN.
438    pub fn http2(mut self, config: Http2Config) -> Self {
439        self.http2_config = Some(config);
440        self
441    }
442
443    /// Terminates TLS for the server using the given certificate configuration.
444    ///
445    /// With TLS set, [`serve`](App::serve) negotiates HTTPS (and HTTP/2 via ALPN by
446    /// default). A malformed certificate or key fails fast at boot. Requires the
447    /// `tls` feature.
448    #[cfg(feature = "tls")]
449    pub fn tls(mut self, config: TlsConfig) -> Self {
450        self.tls_config = Some(config);
451        self
452    }
453
454    /// Registers an observe-only hook that runs when a WebSocket opens.
455    pub fn on_ws_connect<F, Fut>(mut self, hook: F) -> Self
456    where
457        F: Fn(WsConnectInfo) -> Fut + Send + Sync + 'static,
458        Fut: Future<Output = ()> + Send + 'static,
459    {
460        self.on_ws_connect
461            .push(Box::new(move |info| Box::pin(hook(info))));
462        self
463    }
464
465    /// Registers an observe-only hook that runs when a WebSocket closes.
466    ///
467    /// The event carries how long the connection was open and the close code.
468    pub fn on_ws_disconnect<F, Fut>(mut self, hook: F) -> Self
469    where
470        F: Fn(WsDisconnectInfo) -> Fut + Send + Sync + 'static,
471        Fut: Future<Output = ()> + Send + 'static,
472    {
473        self.on_ws_disconnect
474            .push(Box::new(move |info| Box::pin(hook(info))));
475        self
476    }
477
478    /// Sets default multipart upload limits for every form/file route.
479    ///
480    /// A route's own `#[post("/p", upload(...))]` limits override these defaults.
481    pub fn upload_config(mut self, config: UploadConfig) -> Self {
482        self.upload_config = Some(config);
483        self
484    }
485
486    /// Enables the panic boundary: a panic in a handler is caught and turned into
487    /// a `500` response instead of dropping the connection.
488    ///
489    /// Enabled by default. A caught panic fires the [`on_panic`](App::on_panic)
490    /// hooks. The boundary has no effect when the process is built with
491    /// `panic = "abort"`.
492    pub fn catch_panics(mut self) -> Self {
493        self.catch_panics = true;
494        self
495    }
496
497    /// Disables the panic boundary: handler panics propagate and tear down the
498    /// request task instead of becoming a `500` response.
499    pub fn propagate_panics(mut self) -> Self {
500        self.catch_panics = false;
501        self
502    }
503
504    /// Registers a handler that maps a typed error `E` into a response.
505    ///
506    /// When an error carries a source of type `E` (for example one produced by a
507    /// `#[derive(AppError)]` type via `?`), the registered handler receives the
508    /// recovered value and produces the response. Registering a handler for a type
509    /// again replaces the previous one.
510    pub fn exception_handler<E, F, Fut>(mut self, handler: F) -> Self
511    where
512        E: std::error::Error + Send + Sync + 'static,
513        F: Fn(E, ErrorContext) -> Fut + Send + Sync + 'static,
514        Fut: Future<Output = Response> + Send + 'static,
515    {
516        self.exception_handlers.insert(
517            TypeId::of::<E>(),
518            Box::new(move |mut error, ctx| match error.take_source::<E>() {
519                Some(value) => Box::pin(handler(value, ctx)),
520                // The source matched by type id but could not be recovered; fall
521                // back to the default rendering rather than dropping the error.
522                None => Box::pin(async move { error.into_response() }),
523            }),
524        );
525        self
526    }
527
528    /// Rejects mixing a lifespan with startup/shutdown event hooks.
529    fn validate_lifecycle(&self) -> Result<()> {
530        if !self.lifespan.is_empty()
531            && (!self.on_startup.is_empty() || !self.on_shutdown.is_empty())
532        {
533            return Err(Error::internal(
534                "Cannot use `.lifespan(...)` together with `.on_startup(...)` or `.on_shutdown(...)`.\n\
535                 Use either lifespan or event hooks, not both.",
536            )
537            .with_code("LIFECYCLE_CONFLICT"));
538        }
539        Ok(())
540    }
541
542    /// Finalizes the application into its request-handling core.
543    ///
544    /// Routers are flattened into a single route table, OpenAPI documentation
545    /// routes (if configured) are appended, and a [`Matcher`] is compiled.
546    ///
547    /// # Errors
548    ///
549    /// Returns an error if the route table contains an invalid or duplicate path.
550    pub fn build(self) -> Result<AppInner> {
551        self.validate_lifecycle()?;
552
553        let App {
554            mut state,
555            routers,
556            openapi,
557            asyncapi,
558            middleware,
559            on_request,
560            on_response,
561            on_error,
562            on_validation_error,
563            on_panic,
564            catch_panics,
565            exception_handlers,
566            ws_config,
567            on_ws_connect,
568            on_ws_disconnect,
569            upload_config,
570            logger_config,
571            cache,
572            throttler,
573            max_sse_connections,
574            max_request_body_size,
575            idle_timeout,
576            header_read_timeout,
577            http1_config,
578            http2_config,
579            #[cfg(feature = "tls")]
580            tls_config,
581            ..
582        } = self;
583
584        // Build the TLS acceptor up front so a malformed certificate or key fails
585        // fast at boot rather than on the first connection.
586        #[cfg(feature = "tls")]
587        let tls_acceptor = match tls_config {
588            Some(config) => Some(config.into_acceptor()?),
589            None => None,
590        };
591        // The automatic HTTP request log is on unless the logger config disables it.
592        let request_logs = logger_config
593            .as_ref()
594            .map(|config| config.request_logs)
595            .unwrap_or(true);
596
597        // A shutdown channel lets in-flight WebSocket connections close cleanly
598        // when the server stops. The receiver lives in app state; the sender is
599        // held on `AppInner` and flipped by the server on shutdown.
600        let (ws_shutdown, ws_shutdown_rx) = tokio::sync::watch::channel(false);
601        state.insert(crate::ws::WsShutdown(ws_shutdown_rx));
602
603        // A concurrent-connection cap for SSE streams, when configured.
604        if let Some(limit) = max_sse_connections {
605            state.insert(crate::sse::SseLimiter::new(limit));
606        }
607
608        // The request-body size cap read by the body extractors, when configured.
609        if let Some(limit) = max_request_body_size {
610            state.insert(crate::extract::body::AppBodyLimit(limit));
611        }
612
613        // Make the default WebSocket config available to websocket handlers.
614        if let Some(config) = ws_config {
615            // A per-IP connection cap needs one shared counter for the whole app.
616            if let Some(max) = config.ip_connection_limit() {
617                state.insert(crate::ws::WsIpLimiter::new(max));
618            }
619            state.insert(AppWsConfig(config));
620        }
621        // Make the WebSocket lifecycle hooks available to websocket handlers.
622        if !on_ws_connect.is_empty() || !on_ws_disconnect.is_empty() {
623            state.insert(WsHooks {
624                connect: on_ws_connect,
625                disconnect: on_ws_disconnect,
626            });
627        }
628        // Make the default upload config available to form/file handlers.
629        if let Some(config) = upload_config {
630            state.insert(AppUploadConfig(config));
631        }
632        // Make the cache available to handlers and services that inject it.
633        if let Some(cache) = cache {
634            state.insert(cache);
635        }
636        // Make the throttle engine available to generated route checks.
637        if let Some(throttler) = throttler {
638            state.insert(throttler);
639        }
640
641        let mut routes = Vec::new();
642        for router in routers {
643            routes.extend(router.into_routes());
644        }
645
646        if let Some(provider) = openapi {
647            let documentation = provider.documentation_routes(&routes);
648            routes.extend(documentation);
649        }
650
651        if let Some(provider) = asyncapi {
652            let documentation = provider.documentation_routes(&routes);
653            routes.extend(documentation);
654        }
655
656        let matcher = Matcher::build(routes)?;
657        let middleware = resolve_duplicates(middleware)?;
658
659        Ok(AppInner {
660            state: Arc::new(state),
661            matcher,
662            middleware: middleware.into(),
663            on_request: on_request.into(),
664            on_response: on_response.into(),
665            on_error: on_error.into(),
666            on_validation_error: on_validation_error.into(),
667            on_panic: on_panic.into(),
668            catch_panics,
669            request_logs,
670            exception_handlers: Arc::new(exception_handlers),
671            ws_shutdown,
672            idle_timeout,
673            header_read_timeout,
674            http1_config,
675            http2_config,
676            #[cfg(feature = "tls")]
677            tls_acceptor,
678        })
679    }
680
681    /// Runs the application lifecycle and serves on `addr` until a shutdown signal.
682    ///
683    /// Listens for `SIGINT` (Ctrl-C) and, on Unix, `SIGTERM`. The lifecycle runs
684    /// in order: startup (lifespans or `on_startup` hooks), bind, `on_ready`
685    /// hooks, the accept loop, drain, then shutdown (lifespans in reverse, or
686    /// `on_shutdown` hooks).
687    ///
688    /// # Errors
689    ///
690    /// Returns an error for a lifecycle misconfiguration, a failed startup, or a
691    /// bind failure.
692    pub async fn serve(self, addr: impl AsRef<str>) -> Result<()> {
693        self.serve_with_shutdown(addr, shutdown_signal()).await
694    }
695
696    /// Runs the lifecycle, stopping the accept loop when `shutdown` resolves.
697    ///
698    /// Like [`serve`](App::serve) but driven by a caller-supplied future instead
699    /// of `SIGINT`/`SIGTERM`, for custom graceful shutdown (and for tests).
700    pub async fn serve_with_shutdown<S>(self, addr: impl AsRef<str>, shutdown: S) -> Result<()>
701    where
702        S: std::future::Future<Output = ()>,
703    {
704        let reuse_port = self.reuse_port;
705        let addr = addr.as_ref().to_owned();
706        self.serve_listener(shutdown, move || async move {
707            let listener = crate::server::bind_tcp_listener(&addr, reuse_port)
708                .await
709                .map_err(|error| Error::internal(format!("failed to bind {addr}: {error}")))?;
710            let local = listener.local_addr().map_err(|error| {
711                Error::internal(format!("failed to read local address: {error}"))
712            })?;
713            Ok((listener, local, None))
714        })
715        .await
716    }
717
718    /// Serves over a Unix-domain socket at `path`, until a `SIGINT`/`SIGTERM`.
719    ///
720    /// Useful behind a reverse proxy on the same host. A stale socket file at
721    /// `path` is removed before binding. Unix only.
722    #[cfg(unix)]
723    pub async fn serve_unix(self, path: impl AsRef<std::path::Path>) -> Result<()> {
724        self.serve_unix_with_shutdown(path, shutdown_signal()).await
725    }
726
727    /// Serves over a Unix-domain socket at `path`, stopping when `shutdown`
728    /// resolves. Unix only.
729    #[cfg(unix)]
730    pub async fn serve_unix_with_shutdown<S>(
731        self,
732        path: impl AsRef<std::path::Path>,
733        shutdown: S,
734    ) -> Result<()>
735    where
736        S: std::future::Future<Output = ()>,
737    {
738        let path = path.as_ref().to_owned();
739        self.serve_listener(shutdown, move || async move {
740            // Remove a stale socket file so re-binding the path succeeds.
741            let _ = std::fs::remove_file(&path);
742            let listener = tokio::net::UnixListener::bind(&path).map_err(|error| {
743                Error::internal(format!("failed to bind {}: {error}", path.display()))
744            })?;
745            // Unix sockets have no `SocketAddr`; readiness hooks get a placeholder.
746            let placeholder: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder address");
747            Ok((listener, placeholder, Some(path.display().to_string())))
748        })
749        .await
750    }
751
752    /// Shared serve lifecycle: startup, build, bind (via `bind`), readiness, the
753    /// accept loop, and shutdown. `bind` runs *after* startup so the socket is only
754    /// exposed once startup hooks have completed.
755    async fn serve_listener<S, L, F, Fut>(mut self, shutdown: S, bind: F) -> Result<()>
756    where
757        S: std::future::Future<Output = ()>,
758        L: crate::server::IncomingListener,
759        F: FnOnce() -> Fut,
760        Fut: std::future::Future<Output = Result<(L, SocketAddr, Option<String>)>>,
761    {
762        // Install the global logging subscriber first, so the whole lifecycle is
763        // logged. The handle is kept alive for the duration of the run.
764        // Clone (not take) so `build` can still read it for the request-log flag.
765        let logger_config = self.logger_config.clone().unwrap_or_default();
766        let _logger_handle = install_logging(&logger_config);
767        let log = Logger::framework("Tork");
768
769        log.info("Starting Tork application").emit();
770        self.validate_lifecycle().inspect_err(log_boot_error)?;
771        self.run_startup().await.inspect_err(log_boot_error)?;
772
773        let mut lifespan = std::mem::take(&mut self.lifespan);
774        let on_shutdown = std::mem::take(&mut self.on_shutdown);
775        let on_ready = std::mem::take(&mut self.on_ready);
776
777        let app = Arc::new(self.build().inspect_err(log_boot_error)?);
778
779        // Log each mapped route, NestJS-style.
780        let explorer = Logger::framework("RouterExplorer");
781        for route in app.matcher().routes() {
782            explorer
783                .info(format!("Mapped {{{} {}}}", route.method(), route.path()))
784                .emit();
785        }
786
787        // Bind only after startup, so clients cannot connect before the app is ready.
788        let (listener, ready_addr, unix_display) = bind().await.inspect_err(log_boot_error)?;
789
790        for hook in &on_ready {
791            hook(ReadyContext::new(ready_addr))
792                .await
793                .inspect_err(log_boot_error)?;
794        }
795
796        let running_on = match unix_display {
797            Some(path) => format!("unix:{path}"),
798            None => {
799                #[cfg(feature = "tls")]
800                let scheme = if app.tls_acceptor().is_some() { "https" } else { "http" };
801                #[cfg(not(feature = "tls"))]
802                let scheme = "http";
803                format!("{scheme}://{ready_addr}")
804            }
805        };
806        Logger::framework("Tork")
807            .info(format!("Application is running on {running_on}"))
808            .emit();
809
810        run_with_shutdown(app, listener, shutdown).await;
811
812        log.info("Shutting down").emit();
813        run_shutdown(&mut lifespan, &on_shutdown).await;
814        Ok(())
815    }
816
817    /// Runs the startup phase: lifespans in registration order (rolling back the
818    /// already-started ones on failure), or the `on_startup` hooks. Mutates the
819    /// state map with each lifespan's registered resources.
820    async fn run_startup(&mut self) -> Result<()> {
821        if !self.lifespan.is_empty() {
822            for index in 0..self.lifespan.len() {
823                let ctx = LifespanContext::new();
824                if let Err(error) = self.lifespan[index].startup(ctx, &mut self.state).await {
825                    for started in (0..index).rev() {
826                        let _ = self.lifespan[started].shutdown().await;
827                    }
828                    return Err(error);
829                }
830            }
831        } else {
832            for hook in &self.on_startup {
833                hook().await?;
834            }
835        }
836        Ok(())
837    }
838
839    /// Builds the application for in-process testing.
840    ///
841    /// Runs the startup phase (lifespans or `on_startup` hooks) and finalizes the
842    /// app without binding a socket, returning a [`TestApp`] that the
843    /// [`TestClient`](crate::testing::TestClient) drives. The `on_ready` hooks are
844    /// not run, since there is no bound address.
845    pub async fn build_test(self) -> Result<TestApp> {
846        self.build_test_with(|_| {}).await
847    }
848
849    /// Builds the application for testing, applying `apply` to the state map after
850    /// startup (so test overrides win) and before the app is finalized. Used by the
851    /// test client builder to inject resource and dependency overrides.
852    pub(crate) async fn build_test_with(
853        mut self,
854        apply: impl FnOnce(&mut StateMap),
855    ) -> Result<TestApp> {
856        self.validate_lifecycle()?;
857        self.run_startup().await?;
858        apply(&mut self.state);
859        let lifespan = std::mem::take(&mut self.lifespan);
860        let on_shutdown = std::mem::take(&mut self.on_shutdown);
861        let inner = Arc::new(self.build()?);
862        Ok(TestApp {
863            inner,
864            lifespan,
865            on_shutdown,
866        })
867    }
868}
869
870/// Logs a fatal boot error under the framework context.
871fn log_boot_error(error: &Error) {
872    Logger::framework("Tork")
873        .error(error.message().to_owned())
874        .field("code", error.code())
875        .emit();
876}
877
878/// Runs the shutdown phase: lifespans in reverse order, or the `on_shutdown`
879/// hooks. Errors are logged rather than propagated, as shutdown is best-effort.
880async fn run_shutdown(lifespan: &mut [Box<dyn ErasedLifespan>], on_shutdown: &[Hook]) {
881    let log = Logger::framework("Lifecycle");
882    if !lifespan.is_empty() {
883        for cell in lifespan.iter_mut().rev() {
884            if let Err(error) = cell.shutdown().await {
885                log.error(format!("shutdown failed: {}", error.message()))
886                    .emit();
887            }
888        }
889    } else {
890        for hook in on_shutdown {
891            if let Err(error) = hook().await {
892                log.error(format!("shutdown hook failed: {}", error.message()))
893                    .emit();
894            }
895        }
896    }
897}
898
/// An application built for in-process testing.
///
/// Produced by [`App::build_test`] (and the crate-internal `build_test_with`)
/// and consumed by [`TestClient`](crate::testing::TestClient). Holds the
/// finalized app plus the lifespans and `on_shutdown` hooks so
/// [`shutdown`](TestApp::shutdown) can run the teardown that a real server
/// would run on stop.
pub struct TestApp {
    /// The finalized application that requests are dispatched into.
    pub(crate) inner: Arc<AppInner>,
    /// Lifespans whose startup already ran; shut down in reverse order by
    /// [`shutdown`](TestApp::shutdown).
    pub(crate) lifespan: Vec<Box<dyn ErasedLifespan>>,
    /// `on_shutdown` hooks; [`shutdown`](TestApp::shutdown) runs them only when
    /// no lifespan is registered.
    pub(crate) on_shutdown: Vec<Hook>,
}
910
911impl TestApp {
912    /// Runs the shutdown phase (lifespans in reverse, or `on_shutdown` hooks).
913    pub async fn shutdown(mut self) -> Result<()> {
914        run_shutdown(&mut self.lifespan, &self.on_shutdown).await;
915        Ok(())
916    }
917}
918
/// The finalized application: shared state plus a compiled route matcher.
///
/// This is the value shared across all connections. It is cheap to clone behind
/// an `Arc`. The server hands each request to
/// [`handle`](AppInner::handle), which runs the middleware chain around
/// [`dispatch`](AppInner::dispatch).
pub struct AppInner {
    /// Shared application state, exposed through [`state`](AppInner::state).
    state: AppStateRef,
    /// Compiled route matcher, exposed through [`matcher`](AppInner::matcher).
    matcher: Matcher,
    /// Middleware chain that [`handle`](AppInner::handle) runs via `Next`.
    middleware: Arc<[Arc<dyn Middleware>]>,
    /// App-global hooks fired by `handle` before the middleware chain.
    on_request: Arc<[RequestHook]>,
    /// App-global hooks fired by `handle` after the chain, in registration order.
    on_response: Arc<[ResponseHook]>,
    /// App-global hooks fired by `render_error` for non-validation errors.
    on_error: Arc<[ErrorHook]>,
    /// App-global hooks fired by `render_error` for validation errors.
    on_validation_error: Arc<[ValidationErrorHook]>,
    /// Hooks fired by `fire_panic` for a caught handler panic.
    on_panic: Arc<[PanicHook]>,
    /// Whether the panic boundary is enabled (see `catch_panics()`).
    catch_panics: bool,
    /// Whether the automatic HTTP request log is enabled.
    request_logs: bool,
    /// Exception handlers keyed by the `TypeId` of an error's typed source
    /// (looked up with `Error::source_type` in `render_error`).
    exception_handlers: Arc<HashMap<TypeId, ExceptionHandlerFn>>,
    /// Broadcasts a shutdown signal to in-flight WebSocket connections so they
    /// can close cleanly instead of being abruptly dropped.
    ws_shutdown: tokio::sync::watch::Sender<bool>,
    /// Closes a connection after this long with no read/write activity.
    idle_timeout: Option<Duration>,
    /// Deadline for a client to send the full request head; bounds slowloris.
    header_read_timeout: Option<Duration>,
    /// HTTP/1 connection tuning applied to the server builder.
    http1_config: Option<Http1Config>,
    /// HTTP/2 connection tuning applied to the server builder.
    http2_config: Option<Http2Config>,
    /// The rustls acceptor used to terminate TLS, when configured.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
}
951
952impl AppInner {
953    /// Returns the configured request-head read timeout, if any.
954    pub(crate) fn header_read_timeout(&self) -> Option<Duration> {
955        self.header_read_timeout
956    }
957
958    /// Returns the configured connection idle timeout, if any.
959    pub(crate) fn idle_timeout(&self) -> Option<Duration> {
960        self.idle_timeout
961    }
962
963    /// Returns the HTTP/1 connection tuning, if any.
964    pub(crate) fn http1_config(&self) -> Option<&Http1Config> {
965        self.http1_config.as_ref()
966    }
967
968    /// Returns the HTTP/2 connection tuning, if any.
969    pub(crate) fn http2_config(&self) -> Option<&Http2Config> {
970        self.http2_config.as_ref()
971    }
972
973    /// Returns the TLS acceptor used to terminate TLS, if configured.
974    #[cfg(feature = "tls")]
975    pub(crate) fn tls_acceptor(&self) -> Option<&tokio_rustls::TlsAcceptor> {
976        self.tls_acceptor.as_ref()
977    }
978
979    /// Returns the shared application state.
980    pub fn state(&self) -> &AppStateRef {
981        &self.state
982    }
983
984    /// Returns the compiled route matcher.
985    pub fn matcher(&self) -> &Matcher {
986        &self.matcher
987    }
988
989    /// Signals in-flight WebSocket connections to close cleanly.
990    ///
991    /// Called by the server when graceful shutdown begins; idempotent.
992    pub fn begin_ws_shutdown(&self) {
993        let _ = self.ws_shutdown.send(true);
994    }
995
996    /// Runs the middleware chain and dispatches the request.
997    ///
998    /// This is the entrypoint the server calls per request. `on_request` hooks run
999    /// before the chain and `on_response` hooks run after it. The middleware chain
1000    /// wraps [`dispatch`](AppInner::dispatch); a middleware error is rendered into a
1001    /// response (through the error hooks and any exception handler) so the
1002    /// connection is never torn down.
1003    pub async fn handle(self: Arc<Self>, request: Request) -> Response {
1004        // Build request metadata once, only if some hook or handler needs it.
1005        let info = self
1006            .needs_request_info()
1007            .then(|| request_info(request.method(), request.uri(), request.headers(), None));
1008        let start = (!self.on_response.is_empty()).then(Instant::now);
1009
1010        if let Some(info) = &info {
1011            for hook in self.on_request.iter() {
1012                hook(RequestEvent::new(info.clone())).await;
1013            }
1014        }
1015
1016        let next = Next::new(self.clone(), self.middleware.clone());
1017        let response = match next.run(request).await {
1018            Ok(response) => response,
1019            // A middleware-level error is rendered here, after the chain. No route
1020            // matched at this point, so only the app-global hooks apply.
1021            Err(error) => match &info {
1022                Some(info) => self.render_error(error, info, None).await,
1023                None => error.into_response(),
1024            },
1025        };
1026
1027        if let Some(info) = &info {
1028            let status = response.status();
1029            let elapsed = start.map(|start| start.elapsed()).unwrap_or_default();
1030            for hook in self.on_response.iter() {
1031                hook(ResponseEvent::new(info.clone(), status, elapsed)).await;
1032            }
1033        }
1034
1035        response
1036    }
1037
1038    /// Reports whether any registered hook or handler needs request metadata.
1039    ///
1040    /// When nothing observes the request, metadata is never built and errors are
1041    /// rendered directly, keeping the hook machinery zero-cost when unused.
1042    pub(crate) fn needs_request_info(&self) -> bool {
1043        !self.on_request.is_empty()
1044            || !self.on_response.is_empty()
1045            || !self.on_error.is_empty()
1046            || !self.on_validation_error.is_empty()
1047            || !self.on_panic.is_empty()
1048            || !self.exception_handlers.is_empty()
1049    }
1050
1051    /// Reports whether the panic boundary is enabled.
1052    pub(crate) fn catch_panics(&self) -> bool {
1053        self.catch_panics
1054    }
1055
1056    /// Reports whether the automatic HTTP request log is enabled.
1057    pub(crate) fn request_logs(&self) -> bool {
1058        self.request_logs
1059    }
1060
1061    /// Runs the panic hooks for a caught handler panic.
1062    pub(crate) async fn fire_panic(&self, info: &RequestInfo, message: &str) {
1063        for hook in self.on_panic.iter() {
1064            hook(PanicEvent::new(info.clone(), message.to_owned())).await;
1065        }
1066    }
1067
1068    /// Renders an error into a response, running the error hooks and any matching
1069    /// exception handler first.
1070    ///
1071    /// A validation failure fires `on_validation_error`; every other error fires
1072    /// `on_error`. The app-global hooks run first, then the matched route's scoped
1073    /// hooks (when `route` is `Some`). If the error carries a typed source with a
1074    /// registered exception handler, that handler produces the response; otherwise
1075    /// the default problem-details rendering is used.
1076    pub(crate) async fn render_error(
1077        &self,
1078        error: Error,
1079        info: &RequestInfo,
1080        route: Option<&Route>,
1081    ) -> Response {
1082        if error.is_validation() {
1083            for hook in self.on_validation_error.iter() {
1084                hook(ValidationErrorEvent::new(
1085                    info.clone(),
1086                    error.details().to_vec(),
1087                ))
1088                .await;
1089            }
1090            if let Some(route) = route {
1091                fire_validation_hooks(route.validation_hooks(), info, &error).await;
1092            }
1093        } else {
1094            for hook in self.on_error.iter() {
1095                hook(ErrorEvent::new(
1096                    info.clone(),
1097                    error.kind().status(),
1098                    error.static_code(),
1099                    error.message().to_owned(),
1100                ))
1101                .await;
1102            }
1103            if let Some(route) = route {
1104                fire_error_hooks(route.error_hooks(), info, &error).await;
1105            }
1106        }
1107
1108        if let Some(type_id) = error.source_type() {
1109            if let Some(handler) = self.exception_handlers.get(&type_id) {
1110                return handler(error, ErrorContext::new(info.clone())).await;
1111            }
1112        }
1113
1114        error.into_response()
1115    }
1116}
1117
1118/// Fires a slice of scoped `on_request` hooks in order.
1119pub(crate) async fn fire_request_hooks(hooks: &[SharedRequestHook], info: &RequestInfo) {
1120    for hook in hooks {
1121        hook(RequestEvent::new(info.clone())).await;
1122    }
1123}
1124
1125/// Fires a slice of scoped `on_response` hooks in reverse (innermost first).
1126pub(crate) async fn fire_response_hooks(
1127    hooks: &[SharedResponseHook],
1128    info: &RequestInfo,
1129    status: StatusCode,
1130    elapsed: Duration,
1131) {
1132    for hook in hooks.iter().rev() {
1133        hook(ResponseEvent::new(info.clone(), status, elapsed)).await;
1134    }
1135}
1136
1137/// Fires a slice of scoped `on_error` hooks in order.
1138async fn fire_error_hooks(hooks: &[SharedErrorHook], info: &RequestInfo, error: &Error) {
1139    for hook in hooks {
1140        hook(ErrorEvent::new(
1141            info.clone(),
1142            error.kind().status(),
1143            error.static_code(),
1144            error.message().to_owned(),
1145        ))
1146        .await;
1147    }
1148}
1149
1150/// Fires a slice of scoped `on_validation_error` hooks in order.
1151async fn fire_validation_hooks(
1152    hooks: &[SharedValidationErrorHook],
1153    info: &RequestInfo,
1154    error: &Error,
1155) {
1156    for hook in hooks {
1157        hook(ValidationErrorEvent::new(
1158            info.clone(),
1159            error.details().to_vec(),
1160        ))
1161        .await;
1162    }
1163}
1164
1165/// Builds request metadata for the hook events from a request head.
1166pub(crate) fn request_info(
1167    method: &Method,
1168    uri: &Uri,
1169    headers: &HeaderMap,
1170    route: Option<String>,
1171) -> RequestInfo {
1172    let request_id = headers
1173        .get(REQUEST_ID_HEADER)
1174        .and_then(|value| value.to_str().ok())
1175        .map(Arc::<str>::from);
1176    RequestInfo::new(
1177        method.clone(),
1178        Arc::from(uri.path()),
1179        route.map(Arc::<str>::from),
1180        request_id,
1181    )
1182}
1183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Resources;
    use std::sync::atomic::{AtomicBool, Ordering};

    static STARTED: AtomicBool = AtomicBool::new(false);
    static STOPPED: AtomicBool = AtomicBool::new(false);

    #[derive(Clone)]
    struct Boot;

    impl Resources for Boot {
        fn register(&self, _registry: &mut StateMap) {}
    }

    impl Lifespan for Boot {
        async fn startup(_ctx: LifespanContext) -> Result<Self> {
            STARTED.store(true, Ordering::SeqCst);
            Ok(Boot)
        }

        async fn shutdown(self) -> Result<()> {
            STOPPED.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn serve_runs_startup_then_shutdown() {
        // An immediately-ready shutdown future stops the accept loop at once.
        App::new()
            .lifespan::<Boot>()
            .serve_with_shutdown("127.0.0.1:0", async {})
            .await
            .unwrap();

        assert!(STARTED.load(Ordering::SeqCst), "startup should have run");
        assert!(STOPPED.load(Ordering::SeqCst), "shutdown should have run");
    }

    #[test]
    fn lifespan_with_event_hooks_is_a_conflict() {
        let error = App::new()
            .lifespan::<Boot>()
            .on_startup(|| async { Ok(()) })
            .build()
            .err()
            .expect("lifespan plus on_startup should conflict");

        assert_eq!(error.code(), "LIFECYCLE_CONFLICT");
        assert!(
            error
                .message()
                .contains("Use either lifespan or event hooks"),
            "message: {}",
            error.message()
        );
    }

    #[test]
    fn app_default_equals_new() {
        // `Default` must configure the app exactly like `new()`; compare the
        // finalized settings that are observable from inside the crate.
        let from_default = App::default().build().expect("default app builds");
        let from_new = App::new().build().expect("new app builds");

        assert_eq!(from_default.catch_panics(), from_new.catch_panics());
        assert_eq!(from_default.request_logs(), from_new.request_logs());
        assert_eq!(
            from_default.needs_request_info(),
            from_new.needs_request_info()
        );
        assert_eq!(from_default.idle_timeout(), from_new.idle_timeout());
        assert_eq!(
            from_default.header_read_timeout(),
            from_new.header_read_timeout()
        );
        assert_eq!(
            from_default.http1_config().is_some(),
            from_new.http1_config().is_some()
        );
        assert_eq!(
            from_default.http2_config().is_some(),
            from_new.http2_config().is_some()
        );
    }

    #[test]
    fn begin_ws_shutdown_is_idempotent() {
        // Signalling shutdown repeatedly must not panic, even with no open sockets.
        let app = App::new().build().expect("app builds");
        app.begin_ws_shutdown();
        app.begin_ws_shutdown();
    }

    #[test]
    fn websocket_config_is_accepted_by_builder() {
        let app = App::new()
            .websocket_config(crate::ws::WebSocketConfig::new().idle_timeout_secs(5))
            .build();
        assert!(app.is_ok());
    }

    #[test]
    fn upload_config_is_accepted_by_builder() {
        let app = App::new()
            .upload_config(crate::multipart::UploadConfig::new().max_file_size(1024))
            .build();
        assert!(app.is_ok());
    }

    #[test]
    fn on_ws_connect_and_disconnect_hooks_are_accepted_by_builder() {
        let app = App::new()
            .on_ws_connect(|_info| async {})
            .on_ws_disconnect(|_info| async {})
            .build();
        assert!(app.is_ok());
    }

    #[test]
    fn logger_config_is_accepted_by_builder() {
        let app = App::new()
            .logger(crate::logging::LoggerConfig::new())
            .build();
        assert!(app.is_ok());
    }

    #[test]
    fn app_supports_multiple_distinct_state_types() {
        struct Greeting(&'static str);
        struct Counter(u32);

        let app = App::new()
            .state(Greeting("hello"))
            .state(Counter(42))
            .build()
            .expect("app builds");

        // Both distinct state types are retrievable, each keyed by its own type.
        let greeting = app.state().get::<Greeting>().expect("greeting registered");
        let counter = app.state().get::<Counter>().expect("counter registered");
        assert_eq!(greeting.0, "hello");
        assert_eq!(counter.0, 42);
    }
}
1301
#[cfg(test)]
mod hook_tests {
    use super::*;
    use crate::body::box_body;
    use crate::extract::RequestContext;
    use crate::response::empty;
    use crate::router::{HandlerFn, Route};
    use crate::ErrorDetail;
    use bytes::Bytes;
    use http::StatusCode;
    use http_body_util::Full;
    use std::sync::Mutex;

    type Log = Arc<Mutex<Vec<String>>>;

    fn log() -> Log {
        Arc::new(Mutex::new(Vec::new()))
    }

    fn entries(log: &Log) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    fn request(method: Method, uri: &str) -> Request {
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(box_body(Full::new(Bytes::new())))
            .unwrap()
    }

    /// A route whose handler returns the error produced by `make`.
    fn failing_route(make: fn() -> Error) -> Router {
        let handler: HandlerFn = Arc::new(
            move |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async move { Err(make()) })
            },
        );
        Router::new().route(Route::new(Method::GET, "/", handler))
    }

    fn ok_handler() -> HandlerFn {
        Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async { Ok(empty(StatusCode::OK)) })
            },
        )
    }

    fn ok_route() -> Router {
        Router::new().route(Route::new(Method::GET, "/", ok_handler()))
    }

    fn panicking_route() -> Router {
        let handler: HandlerFn = Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async { panic!("handler boom") })
            },
        );
        Router::new().route(Route::new(Method::GET, "/", handler))
    }

    #[tokio::test]
    async fn on_error_fires_for_a_missing_route() {
        let seen = log();
        let recorder = seen.clone();
        let app = App::new()
            .on_error(move |event| {
                let recorder = recorder.clone();
                let code = event.code().to_owned();
                async move { recorder.lock().unwrap().push(code) }
            })
            .build()
            .unwrap();

        let response = app.dispatch(request(Method::GET, "/missing")).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert_eq!(entries(&seen), vec!["NOT_FOUND".to_owned()]);
    }

    #[tokio::test]
    async fn validation_error_fires_only_the_validation_hook() {
        fn validation_error() -> Error {
            Error::unprocessable("invalid")
                .with_code("VALIDATION_ERROR")
                .with_details(vec![ErrorDetail::new("name", "TOO_SHORT", "too short")])
        }

        let errors = log();
        let validations = log();
        let error_rec = errors.clone();
        let validation_rec = validations.clone();

        let app = App::new()
            .include_router(failing_route(validation_error))
            .on_error(move |_event| {
                let rec = error_rec.clone();
                async move { rec.lock().unwrap().push("error".to_owned()) }
            })
            .on_validation_error(move |event| {
                let rec = validation_rec.clone();
                let fields = event.details().len();
                async move { rec.lock().unwrap().push(format!("validation:{fields}")) }
            })
            .build()
            .unwrap();

        let response = app.dispatch(request(Method::GET, "/")).await;
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(entries(&validations), vec!["validation:1".to_owned()]);
        assert!(
            entries(&errors).is_empty(),
            "on_error must not fire for validation"
        );
    }

    #[derive(Debug)]
    struct SampleError;
    impl std::fmt::Display for SampleError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("sample failure")
        }
    }
    impl std::error::Error for SampleError {}

    #[tokio::test]
    async fn exception_handler_replaces_the_response() {
        fn sample_error() -> Error {
            Error::internal("wrapped").with_source(SampleError)
        }

        let app = App::new()
            .include_router(failing_route(sample_error))
            .exception_handler::<SampleError, _, _>(|error, _ctx| async move {
                // The recovered value is the original typed error.
                assert_eq!(error.to_string(), "sample failure");
                empty(StatusCode::SERVICE_UNAVAILABLE)
            })
            .build()
            .unwrap();

        let response = app.dispatch(request(Method::GET, "/")).await;
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn request_hooks_run_in_registration_order() {
        let seen = log();
        let first = seen.clone();
        let second = seen.clone();

        let app = Arc::new(
            App::new()
                .include_router(ok_route())
                .on_request(move |_event| {
                    let rec = first.clone();
                    async move { rec.lock().unwrap().push("first".to_owned()) }
                })
                .on_request(move |_event| {
                    let rec = second.clone();
                    async move { rec.lock().unwrap().push("second".to_owned()) }
                })
                .build()
                .unwrap(),
        );

        let response = app.handle(request(Method::GET, "/")).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            entries(&seen),
            vec!["first".to_owned(), "second".to_owned()]
        );
    }

    #[tokio::test]
    async fn handle_without_hooks_returns_the_handler_response() {
        // `handle` is otherwise only exercised with hooks registered; cover the
        // bare path too, where no hook observes the request.
        let app = Arc::new(App::new().include_router(ok_route()).build().unwrap());

        let response = app.handle(request(Method::GET, "/")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn catch_panics_converts_a_panic_into_a_500() {
        let seen = log();
        let recorder = seen.clone();
        let app = App::new()
            .include_router(panicking_route())
            .catch_panics()
            .on_panic(move |event| {
                let recorder = recorder.clone();
                let message = event.message().to_owned();
                async move { recorder.lock().unwrap().push(message) }
            })
            .build()
            .unwrap();

        let response = app.dispatch(request(Method::GET, "/")).await;
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(entries(&seen), vec!["handler boom".to_owned()]);
    }

    #[tokio::test]
    #[should_panic(expected = "handler boom")]
    async fn without_catch_panics_a_panic_propagates() {
        let app = App::new()
            .propagate_panics()
            .include_router(panicking_route())
            .build()
            .unwrap();
        let _ = app.dispatch(request(Method::GET, "/")).await;
    }

    #[tokio::test]
    async fn scoped_on_request_fires_only_for_its_router() {
        let seen = log();
        let recorder = seen.clone();
        let scoped = Router::new()
            .route(Route::new(Method::GET, "/a", ok_handler()))
            .on_request(move |_event| {
                let recorder = recorder.clone();
                async move { recorder.lock().unwrap().push("a".to_owned()) }
            });
        let plain = Router::new().route(Route::new(Method::GET, "/b", ok_handler()));
        let app = App::new()
            .include_router(scoped)
            .include_router(plain)
            .build()
            .unwrap();

        let _ = app.dispatch(request(Method::GET, "/a")).await;
        let _ = app.dispatch(request(Method::GET, "/b")).await;
        assert_eq!(entries(&seen), vec!["a".to_owned()]);
    }

    #[tokio::test]
    async fn scoped_on_error_runs_after_the_global_hook() {
        let seen = log();
        let global = seen.clone();
        let scoped = seen.clone();

        let router = failing_route(|| Error::not_found("missing")).on_error(move |_event| {
            let scoped = scoped.clone();
            async move { scoped.lock().unwrap().push("scoped".to_owned()) }
        });
        let app = App::new()
            .on_error(move |_event| {
                let global = global.clone();
                async move { global.lock().unwrap().push("global".to_owned()) }
            })
            .include_router(router)
            .build()
            .unwrap();

        let _ = app.dispatch(request(Method::GET, "/")).await;
        assert_eq!(
            entries(&seen),
            vec!["global".to_owned(), "scoped".to_owned()]
        );
    }

    #[tokio::test]
    async fn scoped_on_response_hooks_fire_in_reverse() {
        let seen = log();
        let first = seen.clone();
        let second = seen.clone();
        let router = Router::new()
            .route(Route::new(Method::GET, "/", ok_handler()))
            .on_response(move |_event| {
                let first = first.clone();
                async move { first.lock().unwrap().push("first".to_owned()) }
            })
            .on_response(move |_event| {
                let second = second.clone();
                async move { second.lock().unwrap().push("second".to_owned()) }
            });
        let app = App::new().include_router(router).build().unwrap();

        let _ = app.dispatch(request(Method::GET, "/")).await;
        assert_eq!(
            entries(&seen),
            vec!["second".to_owned(), "first".to_owned()]
        );
    }
}