Skip to main content

typeway_server/
server.rs

1//! The type-safe [`Server`] builder and [`serve`] convenience function.
2
3use std::convert::Infallible;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use hyper_util::rt::{TokioExecutor, TokioIo};
10use tokio::net::TcpListener;
11
12use typeway_core::ApiSpec;
13
14use crate::body::BoxBody;
15use crate::router::{Router, RouterService};
16use crate::serves::Serves;
17
18/// A type-safe HTTP server parameterized by an API specification.
19///
20/// The `A` type parameter is the API type — a tuple of endpoints. The server
21/// ensures at compile time (via [`Serves`]) that every endpoint has a handler.
22///
23/// # Example
24///
25/// ```ignore
26/// type API = (
27///     GetEndpoint<path!("hello"), String>,
28/// );
29///
30/// let server = Server::<API>::new((hello_handler,));
31/// server.serve("127.0.0.1:3000".parse().unwrap()).await?;
32/// ```
33pub struct Server<A: ApiSpec> {
34    router: Arc<Router>,
35    _api: PhantomData<A>,
36}
37
38impl<A: ApiSpec> Server<A> {
39    /// Create a new server with handlers covering the full API.
40    ///
41    /// Fails to compile if the handler tuple doesn't match the API type.
42    pub fn new<H: Serves<A>>(handlers: H) -> Self {
43        let mut router = Router::new();
44        handlers.register(&mut router);
45        Server {
46            router: Arc::new(router),
47            _api: PhantomData,
48        }
49    }
50
51    /// Create a server from a pre-built router.
52    ///
53    /// Used internally by [`EffectfulServer::ready`](crate::effects::EffectfulServer).
54    pub(crate) fn from_router(router: Arc<Router>) -> Self {
55        Server {
56            router,
57            _api: PhantomData,
58        }
59    }
60
61    /// Set a path prefix for all routes in this server.
62    ///
63    /// Only requests whose path starts with the prefix will match. The prefix
64    /// is stripped before route matching.
65    ///
66    /// # Example
67    ///
68    /// ```ignore
69    /// // Routes are /api/v1/hello, /api/v1/users, etc.
70    /// Server::<API>::new(handlers)
71    ///     .nest("/api/v1")
72    ///     .serve(addr)
73    ///     .await?;
74    /// ```
75    pub fn nest(self, prefix: &str) -> Self {
76        self.router.set_prefix(prefix);
77        self
78    }
79
80    /// Set the maximum request body size in bytes.
81    ///
82    /// Bodies exceeding this limit are rejected with 413 Payload Too Large.
83    /// Default: 2 MiB (2,097,152 bytes).
84    pub fn max_body_size(self, max: usize) -> Self {
85        self.router.set_max_body_size(max);
86        self
87    }
88
89    /// Add shared state accessible via [`State<T>`](crate::extract::State) extractors.
90    pub fn with_state<T: Clone + Send + Sync + 'static>(self, state: T) -> Self {
91        self.router.set_state_injector(Arc::new(move |ext| {
92            ext.insert(state.clone());
93        }));
94        self
95    }
96
97    /// Enable OpenAPI spec serving at `/openapi.json` and Swagger UI at `/docs`.
98    ///
99    /// Requires `feature = "openapi"` and that the API type implements
100    /// [`ApiToSpec`](typeway_openapi::ApiToSpec).
101    ///
102    /// # Example
103    ///
104    /// ```ignore
105    /// Server::<API>::new(handlers)
106    ///     .with_openapi("My API", "1.0.0")
107    ///     .serve(addr)
108    ///     .await?;
109    /// ```
110    #[cfg(feature = "openapi")]
111    pub fn with_openapi(self, title: &str, version: &str) -> Self
112    where
113        A: typeway_openapi::ApiToSpec,
114    {
115        let spec = A::to_spec(title, version);
116        let spec_json = std::sync::Arc::new(
117            serde_json::to_string_pretty(&spec).expect("OpenAPI spec serialization failed"),
118        );
119
120        let router = &self.router;
121
122        let spec_json_str =
123            serde_json::to_string(&spec).expect("OpenAPI spec serialization failed");
124
125        router.add_route(
126            http::Method::GET,
127            "/openapi.json".to_string(),
128            crate::openapi::exact_match(&["openapi.json"]),
129            crate::openapi::spec_handler(spec_json.clone()),
130        );
131
132        router.add_route(
133            http::Method::GET,
134            "/docs".to_string(),
135            crate::openapi::exact_match(&["docs"]),
136            crate::openapi::docs_handler(title, version, &spec_json_str),
137        );
138
139        self
140    }
141
142    /// Enable OpenAPI spec serving with handler documentation applied.
143    ///
144    /// Like [`with_openapi`](Self::with_openapi), but patches the generated
145    /// spec with documentation metadata extracted from handler functions via
146    /// the `#[documented_handler]` attribute macro.
147    ///
148    /// # Example
149    ///
150    /// ```ignore
151    /// use typeway_macros::documented_handler;
152    ///
153    /// /// List all users.
154    /// ///
155    /// /// Returns a paginated list of users with optional filtering.
156    /// #[documented_handler(tags = "users")]
157    /// async fn list_users() -> Json<Vec<User>> { /* ... */ }
158    ///
159    /// /// Get a user by ID.
160    /// #[documented_handler(tags = "users")]
161    /// async fn get_user(id: Path<u32>) -> Json<User> { /* ... */ }
162    ///
163    /// Server::<API>::new(handlers)
164    ///     .with_openapi_docs("My API", "1.0.0", &[LIST_USERS_DOC, GET_USER_DOC])
165    ///     .serve(addr)
166    ///     .await?;
167    /// ```
168    #[cfg(feature = "openapi")]
169    pub fn with_openapi_docs(
170        self,
171        title: &str,
172        version: &str,
173        docs: &[typeway_core::HandlerDoc],
174    ) -> Self
175    where
176        A: typeway_openapi::ApiToSpec,
177    {
178        let mut spec = A::to_spec(title, version);
179        typeway_openapi::apply_handler_docs(&mut spec, docs);
180
181        let spec_json = std::sync::Arc::new(
182            serde_json::to_string_pretty(&spec).expect("OpenAPI spec serialization failed"),
183        );
184
185        let router = &self.router;
186
187        let spec_json_str =
188            serde_json::to_string(&spec).expect("OpenAPI spec serialization failed");
189
190        router.add_route(
191            http::Method::GET,
192            "/openapi.json".to_string(),
193            crate::openapi::exact_match(&["openapi.json"]),
194            crate::openapi::spec_handler(spec_json.clone()),
195        );
196
197        router.add_route(
198            http::Method::GET,
199            "/docs".to_string(),
200            crate::openapi::exact_match(&["docs"]),
201            crate::openapi::docs_handler(title, version, &spec_json_str),
202        );
203
204        self
205    }
206
207    /// Serve static files from a directory at a given URL prefix.
208    ///
209    /// Requests to `{prefix}/{path}` will serve files from `{dir}/{path}`.
210    /// 404 is returned if the file doesn't exist.
211    ///
212    /// # Example
213    ///
214    /// ```ignore
215    /// Server::<API>::new(handlers)
216    ///     .with_static_files("/static", "./public")
217    ///     .serve(addr)
218    ///     .await?;
219    /// ```
220    pub fn with_static_files(self, prefix: &str, dir: impl Into<std::path::PathBuf>) -> Self {
221        let dir: std::path::PathBuf = dir.into();
222        let prefix_segments: Vec<String> = prefix
223            .split('/')
224            .filter(|s| !s.is_empty())
225            .map(|s| s.to_string())
226            .collect();
227        let prefix_len = prefix_segments.len();
228
229        let router = &self.router;
230
231        let dir = Arc::new(dir);
232        let prefix_segs = Arc::new(prefix_segments);
233
234        // Add a catch-all route for the prefix (matches /static, /static/, /static/foo).
235        router.add_route(
236            http::Method::GET,
237            format!("{prefix}/{{*path}}"),
238            {
239                let prefix_segs = prefix_segs.clone();
240                Box::new(move |segments: &[&str]| {
241                    // Match: prefix exactly, or prefix + any file path
242                    segments.len() >= prefix_segs.len()
243                        && segments[..prefix_segs.len()]
244                            .iter()
245                            .zip(prefix_segs.iter())
246                            .all(|(a, b)| *a == b.as_str())
247                })
248            },
249            {
250                let dir = dir.clone();
251                std::sync::Arc::new(move |parts: http::request::Parts, _body: bytes::Bytes| {
252                    let dir = dir.clone();
253                    Box::pin(async move {
254                        let path = parts.uri.path();
255                        // Strip prefix to get the file path.
256                        let file_path: String = path
257                            .splitn(prefix_len + 2, '/')
258                            .skip(prefix_len + 1)
259                            .collect::<Vec<_>>()
260                            .join("/");
261
262                        // Prevent directory traversal.
263                        if file_path.contains("..") {
264                            let mut res = http::Response::new(crate::body::body_from_string(
265                                "Forbidden".to_string(),
266                            ));
267                            *res.status_mut() = http::StatusCode::FORBIDDEN;
268                            return res;
269                        }
270
271                        let full_path = if file_path.is_empty() {
272                            // /static or /static/ → try index.html
273                            dir.join("index.html")
274                        } else {
275                            let p = dir.join(&file_path);
276                            // If it's a directory, try index.html inside it
277                            if p.is_dir() {
278                                p.join("index.html")
279                            } else {
280                                p
281                            }
282                        };
283
284                        match tokio::fs::read(&full_path).await {
285                            Ok(contents) => {
286                                let mime = mime_from_path(&full_path);
287                                let body =
288                                    crate::body::body_from_bytes(bytes::Bytes::from(contents));
289                                let mut res = http::Response::new(body);
290                                if let Ok(val) = http::HeaderValue::from_str(mime) {
291                                    res.headers_mut().insert(http::header::CONTENT_TYPE, val);
292                                }
293                                res
294                            }
295                            Err(_) => {
296                                let mut res = http::Response::new(crate::body::body_from_string(
297                                    "Not Found".to_string(),
298                                ));
299                                *res.status_mut() = http::StatusCode::NOT_FOUND;
300                                res
301                            }
302                        }
303                    }) as crate::handler::ResponseFuture
304                })
305            },
306        );
307
308        self
309    }
310
311    /// Serve a file as the fallback for unmatched routes (SPA mode).
312    ///
313    /// When no API route matches, the given file (typically `index.html`)
314    /// is served. This enables client-side routing in single-page apps.
315    ///
316    /// # Example
317    ///
318    /// ```ignore
319    /// Server::<API>::new(handlers)
320    ///     .with_static_files("/static", "./public")
321    ///     .with_spa_fallback("./public/index.html")
322    ///     .serve(addr)
323    ///     .await?;
324    /// ```
325    pub fn with_spa_fallback(self, index_path: impl Into<std::path::PathBuf>) -> Self {
326        let index_path: std::path::PathBuf = index_path.into();
327
328        // Try to read the file at startup. If it fails, serve an error page instead.
329        let html = match std::fs::read_to_string(&index_path) {
330            Ok(contents) => Arc::new(contents),
331            Err(e) => {
332                tracing::warn!(
333                    "WARNING: SPA fallback file not found: {} ({}). \
334                     Unmatched routes will show an error page.",
335                    index_path.display(),
336                    e
337                );
338                let error_page = format!(
339                    "<!DOCTYPE html><html><body>\
340                     <h1>Frontend Not Available</h1>\
341                     <p>The SPA fallback file <code>{}</code> could not be loaded: {}</p>\
342                     <p>If running locally, build the frontend first:</p>\
343                     <pre>cd examples/realworld/frontend\nelm make src/Main.elm --output=public/elm.js</pre>\
344                     </body></html>",
345                    index_path.display(),
346                    e
347                );
348                Arc::new(error_page)
349            }
350        };
351
352        self.set_fallback_raw(Arc::new(move |req| {
353            let html = html.clone();
354            let path = req.uri().path().to_string();
355            Box::pin(async move {
356                // Don't serve SPA HTML for paths that look like file requests
357                // (contain a dot in the last segment, e.g. /foo/bar.js).
358                let last_segment = path.rsplit('/').next().unwrap_or("");
359                if last_segment.contains('.') {
360                    let mut res =
361                        http::Response::new(crate::body::body_from_string("Not Found".to_string()));
362                    *res.status_mut() = http::StatusCode::NOT_FOUND;
363                    return res;
364                }
365
366                let body = crate::body::body_from_string(html.to_string());
367                let mut res = http::Response::new(body);
368                res.headers_mut().insert(
369                    http::header::CONTENT_TYPE,
370                    http::HeaderValue::from_static("text/html; charset=utf-8"),
371                );
372                res
373            })
374        }));
375
376        self
377    }
378
379    /// Set a raw fallback function on the router.
380    ///
381    /// Used by `with_fallback` and `with_axum_fallback`.
382    pub(crate) fn set_fallback_raw(&self, fallback: crate::router::FallbackService) {
383        let router = &self.router;
384        router.set_fallback(fallback);
385    }
386
387    /// Set a fallback Tower service for requests that don't match any typeway route.
388    ///
389    /// This enables embedding an Axum router (or any Tower service) inside
390    /// a typeway server — the reverse of `into_axum_router()`.
391    ///
392    /// # Example
393    ///
394    /// ```ignore
395    /// let axum_routes = axum::Router::new()
396    ///     .route("/health", get(|| async { "ok" }));
397    ///
398    /// Server::<API>::new(handlers)
399    ///     .with_fallback(axum_routes)
400    ///     .serve(addr)
401    ///     .await?;
402    /// ```
403    pub fn with_fallback<S>(self, service: S) -> Self
404    where
405        S: tower_service::Service<
406                http::Request<hyper::body::Incoming>,
407                Response = http::Response<BoxBody>,
408                Error = Infallible,
409            > + Clone
410            + Send
411            + Sync
412            + 'static,
413        S::Future: Send + 'static,
414    {
415        self.set_fallback_raw(Arc::new(
416            move |req: http::Request<hyper::body::Incoming>| {
417                let mut svc = service.clone();
418                Box::pin(async move {
419                    tower_service::Service::call(&mut svc, req)
420                        .await
421                        .unwrap_or_else(|e| match e {})
422                })
423            },
424        ));
425        self
426    }
427
428    /// Create a unified REST + gRPC server.
429    ///
430    /// Returns a [`GrpcServer`](crate::grpc::GrpcServer) that serves both REST
431    /// and gRPC on the same port. Incoming requests with
432    /// `content-type: application/grpc*` are translated to REST calls via the
433    /// gRPC bridge, while all other requests are routed normally.
434    ///
435    /// Requires `feature = "grpc"`.
436    ///
437    /// # Example
438    ///
439    /// ```ignore
440    /// Server::<API>::new(handlers)
441    ///     .with_state(state)
442    ///     .with_grpc("UserService", "users.v1")
443    ///     .serve("0.0.0.0:3000".parse()?)
444    ///     .await?;
445    /// ```
446    #[cfg(feature = "grpc")]
447    pub fn with_grpc(self, service_name: &str, package: &str) -> crate::grpc::GrpcServer<A>
448    where
449        A: typeway_grpc::CollectRpcs + typeway_grpc::GrpcReady,
450    {
451        crate::grpc::make_grpc_server(self.router, service_name, package)
452    }
453
454    /// Apply a Tower middleware layer to the server.
455    ///
456    /// The layer wraps the entire router service. This is the same API
457    /// as Axum's `.layer()` — any `tower::Layer` that accepts the router
458    /// service type can be used.
459    ///
460    /// # Example
461    ///
462    /// ```ignore
463    /// use tower_http::trace::TraceLayer;
464    /// use tower_http::cors::CorsLayer;
465    ///
466    /// Server::<API>::new(handlers)
467    ///     .layer(TraceLayer::new_for_http())
468    ///     .layer(CorsLayer::permissive())
469    ///     .serve(addr)
470    ///     .await?;
471    /// ```
472    pub fn layer<L>(self, layer: L) -> LayeredServer<L::Service>
473    where
474        L: tower_layer::Layer<RouterService>,
475        L::Service: tower_service::Service<
476                http::Request<hyper::body::Incoming>,
477                Response = http::Response<BoxBody>,
478                Error = Infallible,
479            > + Clone
480            + Send
481            + 'static,
482        <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
483            Send + 'static,
484    {
485        let router = self.router.clone();
486        let svc = RouterService::new(self.router);
487        let layered = layer.layer(svc);
488        LayeredServer {
489            service: layered,
490            router,
491        }
492    }
493
494    /// Start serving HTTPS with TLS.
495    ///
496    /// Requires `feature = "tls"`.
497    ///
498    /// ```ignore
499    /// let tls = TlsConfig::from_pem("cert.pem", "key.pem")?;
500    /// Server::<API>::new(handlers)
501    ///     .serve_tls("0.0.0.0:443".parse()?, tls)
502    ///     .await?;
503    /// ```
504    #[cfg(feature = "tls")]
505    pub async fn serve_tls(
506        self,
507        addr: SocketAddr,
508        tls: crate::tls::TlsConfig,
509    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
510        let listener = TcpListener::bind(addr).await?;
511        tracing::info!("Listening on https://{addr}");
512        let router = self.router.clone();
513        crate::tls::serve_tls_loop(listener, tls, move || {
514            hyper_util::service::TowerToHyperService::new(RouterService::new(router.clone()))
515        })
516        .await
517    }
518
519    /// Get the inner [`RouterService`] as a Tower service.
520    pub fn into_service(self) -> RouterService {
521        RouterService::new(self.router)
522    }
523
524    /// Get the inner router (for integration with other frameworks).
525    pub fn into_router(self) -> Router {
526        Arc::try_unwrap(self.router).unwrap_or_else(|_| {
527            panic!("cannot unwrap router — it has been cloned");
528        })
529    }
530
531    /// Start serving HTTP requests on the given address.
532    pub async fn serve(
533        self,
534        addr: SocketAddr,
535    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
536        let listener = TcpListener::bind(addr).await?;
537        tracing::info!("Listening on http://{addr}");
538        self.serve_with_shutdown(listener, std::future::pending())
539            .await
540    }
541
542    /// Start serving with graceful shutdown.
543    ///
544    /// Supports both HTTP/1.1 and HTTP/2 via automatic protocol detection.
545    /// The server stops accepting new connections when the `shutdown` future
546    /// completes. Existing connections are allowed to finish.
547    ///
548    /// # Example
549    ///
550    /// ```ignore
551    /// let server = Server::<API>::new(handlers);
552    /// let listener = TcpListener::bind("0.0.0.0:3000").await?;
553    ///
554    /// server.serve_with_shutdown(listener, async {
555    ///     tokio::signal::ctrl_c().await.ok();
556    /// }).await?;
557    /// ```
558    pub async fn serve_with_shutdown(
559        self,
560        listener: TcpListener,
561        shutdown: impl Future<Output = ()> + Send,
562    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
563        tokio::pin!(shutdown);
564
565        loop {
566            tokio::select! {
567                result = listener.accept() => {
568                    let (stream, _) = result?;
569                    let io = TokioIo::new(stream);
570                    let svc = RouterService::new(self.router.clone());
571                    let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
572
573                    tokio::task::spawn(async move {
574                        if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
575                            .serve_connection(io, hyper_svc)
576                            .await
577                        {
578                            tracing::debug!("Connection closed: {e}");
579                        }
580                    });
581                }
582                () = &mut shutdown => {
583                    tracing::info!("Shutting down gracefully...");
584                    return Ok(());
585                }
586            }
587        }
588    }
589}
590
591/// A server with Tower middleware layers applied.
592///
593/// Created by [`Server::layer`]. Supports further `.layer()` calls and `.serve()`.
594pub struct LayeredServer<S> {
595    /// The layered service. Exposed for advanced use cases (e.g., manual serving).
596    pub service: S,
597    /// Reference to the underlying router for post-layer configuration.
598    pub(crate) router: Arc<Router>,
599}
600
601impl<S> LayeredServer<S> {
602    /// Add shared state accessible via [`State<T>`](crate::extract::State) extractors.
603    pub fn with_state<T: Clone + Send + Sync + 'static>(self, state: T) -> Self {
604        self.router.set_state_injector(Arc::new(move |ext| {
605            ext.insert(state.clone());
606        }));
607        self
608    }
609
610    /// Set the maximum request body size.
611    pub fn max_body_size(self, max: usize) -> Self {
612        self.router.set_max_body_size(max);
613        self
614    }
615
616    /// Set a path prefix for all routes.
617    pub fn nest(self, prefix: &str) -> Self {
618        self.router.set_prefix(prefix);
619        self
620    }
621
622    /// Serve static files from a directory.
623    pub fn with_static_files(self, prefix: &str, dir: impl Into<std::path::PathBuf>) -> Self {
624        // Delegate to the shared router's add_route via the same logic as Server.
625        let dir: std::path::PathBuf = dir.into();
626        let prefix_segments: Vec<String> = prefix
627            .split('/')
628            .filter(|s| !s.is_empty())
629            .map(|s| s.to_string())
630            .collect();
631        let prefix_len = prefix_segments.len();
632        let dir = Arc::new(dir);
633        let prefix_segs = Arc::new(prefix_segments);
634
635        self.router.add_route(
636            http::Method::GET,
637            format!("{prefix}/{{*path}}"),
638            {
639                let prefix_segs = prefix_segs.clone();
640                Box::new(move |segments: &[&str]| {
641                    segments.len() >= prefix_segs.len()
642                        && segments[..prefix_segs.len()]
643                            .iter()
644                            .zip(prefix_segs.iter())
645                            .all(|(a, b)| *a == b.as_str())
646                })
647            },
648            {
649                let dir = dir.clone();
650                std::sync::Arc::new(move |parts: http::request::Parts, _body: bytes::Bytes| {
651                    let dir = dir.clone();
652                    Box::pin(async move {
653                        let path = parts.uri.path();
654                        let file_path: String = path
655                            .splitn(prefix_len + 2, '/')
656                            .skip(prefix_len + 1)
657                            .collect::<Vec<_>>()
658                            .join("/");
659                        if file_path.contains("..") {
660                            let mut res = http::Response::new(crate::body::body_from_string(
661                                "Forbidden".to_string(),
662                            ));
663                            *res.status_mut() = http::StatusCode::FORBIDDEN;
664                            return res;
665                        }
666                        let full_path = if file_path.is_empty() {
667                            dir.join("index.html")
668                        } else {
669                            let p = dir.join(&file_path);
670                            if p.is_dir() {
671                                p.join("index.html")
672                            } else {
673                                p
674                            }
675                        };
676                        match tokio::fs::read(&full_path).await {
677                            Ok(contents) => {
678                                let mime = mime_from_path(&full_path);
679                                let body =
680                                    crate::body::body_from_bytes(bytes::Bytes::from(contents));
681                                let mut res = http::Response::new(body);
682                                if let Ok(val) = http::HeaderValue::from_str(mime) {
683                                    res.headers_mut().insert(http::header::CONTENT_TYPE, val);
684                                }
685                                res
686                            }
687                            Err(_) => {
688                                let mut res = http::Response::new(crate::body::body_from_string(
689                                    "Not Found".to_string(),
690                                ));
691                                *res.status_mut() = http::StatusCode::NOT_FOUND;
692                                res
693                            }
694                        }
695                    }) as crate::handler::ResponseFuture
696                })
697            },
698        );
699        self
700    }
701
702    /// Serve a file as SPA fallback for unmatched routes.
703    pub fn with_spa_fallback(self, index_path: impl Into<std::path::PathBuf>) -> Self {
704        let index_path: std::path::PathBuf = index_path.into();
705        let html = match std::fs::read_to_string(&index_path) {
706            Ok(contents) => Arc::new(contents),
707            Err(e) => {
708                tracing::warn!(
709                    "WARNING: SPA fallback file not found: {} ({})",
710                    index_path.display(),
711                    e
712                );
713                Arc::new(format!(
714                    "<!DOCTYPE html><html><body>\
715                     <h1>Frontend Not Available</h1>\
716                     <p><code>{}</code>: {}</p></body></html>",
717                    index_path.display(),
718                    e
719                ))
720            }
721        };
722        self.router.set_fallback(Arc::new(move |req| {
723            let html = html.clone();
724            let path = req.uri().path().to_string();
725            Box::pin(async move {
726                let last_segment = path.rsplit('/').next().unwrap_or("");
727                if last_segment.contains('.') {
728                    let mut res =
729                        http::Response::new(crate::body::body_from_string("Not Found".to_string()));
730                    *res.status_mut() = http::StatusCode::NOT_FOUND;
731                    return res;
732                }
733                let body = crate::body::body_from_string(html.to_string());
734                let mut res = http::Response::new(body);
735                res.headers_mut().insert(
736                    http::header::CONTENT_TYPE,
737                    http::HeaderValue::from_static("text/html; charset=utf-8"),
738                );
739                res
740            })
741        }));
742        self
743    }
744}
745
746impl<S> LayeredServer<S>
747where
748    S: tower_service::Service<
749            http::Request<hyper::body::Incoming>,
750            Response = http::Response<BoxBody>,
751            Error = Infallible,
752        > + Clone
753        + Send
754        + 'static,
755    S::Future: Send + 'static,
756{
757    /// Apply another Tower middleware layer.
758    pub fn layer<L>(self, layer: L) -> LayeredServer<L::Service>
759    where
760        L: tower_layer::Layer<S>,
761        L::Service: tower_service::Service<
762                http::Request<hyper::body::Incoming>,
763                Response = http::Response<BoxBody>,
764                Error = Infallible,
765            > + Clone
766            + Send
767            + 'static,
768        <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
769            Send + 'static,
770    {
771        LayeredServer {
772            service: layer.layer(self.service),
773            router: self.router,
774        }
775    }
776
777    /// Start serving HTTP requests on the given address.
778    pub async fn serve(
779        self,
780        addr: SocketAddr,
781    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
782        let listener = TcpListener::bind(addr).await?;
783        tracing::info!("Listening on http://{addr}");
784        self.serve_with_shutdown(listener, std::future::pending())
785            .await
786    }
787
788    /// Start serving with graceful shutdown.
789    pub async fn serve_with_shutdown(
790        self,
791        listener: TcpListener,
792        shutdown: impl Future<Output = ()> + Send,
793    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
794        tokio::pin!(shutdown);
795
796        loop {
797            tokio::select! {
798                result = listener.accept() => {
799                    let (stream, _) = result?;
800                    let io = TokioIo::new(stream);
801                    let svc = self.service.clone();
802                    let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
803
804                    tokio::task::spawn(async move {
805                        if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
806                            .serve_connection(io, hyper_svc)
807                            .await
808                        {
809                            tracing::debug!("Connection closed: {e}");
810                        }
811                    });
812                }
813                () = &mut shutdown => {
814                    tracing::info!("Shutting down gracefully...");
815                    return Ok(());
816                }
817            }
818        }
819    }
820}
821
822/// Convenience function to create and serve an API.
823///
824/// # Example
825///
826/// ```ignore
827/// serve::<API, _>("127.0.0.1:3000".parse().unwrap(), (handler1, handler2)).await?;
828/// ```
829pub async fn serve<A: ApiSpec, H: Serves<A>>(
830    addr: SocketAddr,
831    handlers: H,
832) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
833    Server::<A>::new(handlers).serve(addr).await
834}
835
/// Guess a `Content-Type` value from a file's extension.
///
/// The extension is compared case-insensitively (`LOGO.PNG` maps to
/// `image/png`). Text types carry `charset=utf-8`. Paths with no extension,
/// a non-UTF-8 extension, or an unknown one map to
/// `application/octet-stream`.
fn mime_from_path(path: &std::path::Path) -> &'static str {
    // Lower-case once so upper- or mixed-case extensions still match.
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase());
    match ext.as_deref() {
        Some("html") | Some("htm") => "text/html; charset=utf-8",
        Some("css") => "text/css; charset=utf-8",
        Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
        Some("json") | Some("map") => "application/json",
        Some("webmanifest") => "application/manifest+json",
        Some("png") => "image/png",
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("svg") => "image/svg+xml",
        Some("webp") => "image/webp",
        Some("avif") => "image/avif",
        Some("ico") => "image/x-icon",
        Some("woff") => "font/woff",
        Some("woff2") => "font/woff2",
        Some("ttf") => "font/ttf",
        Some("otf") => "font/otf",
        Some("txt") => "text/plain; charset=utf-8",
        Some("xml") => "application/xml",
        Some("pdf") => "application/pdf",
        Some("wasm") => "application/wasm",
        Some("mp4") => "video/mp4",
        Some("webm") => "video/webm",
        _ => "application/octet-stream",
    }
}