1use std::convert::Infallible;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use hyper_util::rt::{TokioExecutor, TokioIo};
10use tokio::net::TcpListener;
11
12use typeway_core::ApiSpec;
13
14use crate::body::BoxBody;
15use crate::router::{Router, RouterService};
16use crate::serves::Serves;
17
18pub struct Server<A: ApiSpec> {
34 router: Arc<Router>,
35 _api: PhantomData<A>,
36}
37
38impl<A: ApiSpec> Server<A> {
39 pub fn new<H: Serves<A>>(handlers: H) -> Self {
43 let mut router = Router::new();
44 handlers.register(&mut router);
45 Server {
46 router: Arc::new(router),
47 _api: PhantomData,
48 }
49 }
50
51 pub(crate) fn from_router(router: Arc<Router>) -> Self {
55 Server {
56 router,
57 _api: PhantomData,
58 }
59 }
60
61 pub fn nest(self, prefix: &str) -> Self {
76 self.router.set_prefix(prefix);
77 self
78 }
79
80 pub fn max_body_size(self, max: usize) -> Self {
85 self.router.set_max_body_size(max);
86 self
87 }
88
89 pub fn with_state<T: Clone + Send + Sync + 'static>(self, state: T) -> Self {
91 self.router.set_state_injector(Arc::new(move |ext| {
92 ext.insert(state.clone());
93 }));
94 self
95 }
96
97 #[cfg(feature = "openapi")]
111 pub fn with_openapi(self, title: &str, version: &str) -> Self
112 where
113 A: typeway_openapi::ApiToSpec,
114 {
115 let spec = A::to_spec(title, version);
116 let spec_json = std::sync::Arc::new(
117 serde_json::to_string_pretty(&spec).expect("OpenAPI spec serialization failed"),
118 );
119
120 let router = &self.router;
121
122 let spec_json_str =
123 serde_json::to_string(&spec).expect("OpenAPI spec serialization failed");
124
125 router.add_route(
126 http::Method::GET,
127 "/openapi.json".to_string(),
128 crate::openapi::exact_match(&["openapi.json"]),
129 crate::openapi::spec_handler(spec_json.clone()),
130 );
131
132 router.add_route(
133 http::Method::GET,
134 "/docs".to_string(),
135 crate::openapi::exact_match(&["docs"]),
136 crate::openapi::docs_handler(title, version, &spec_json_str),
137 );
138
139 self
140 }
141
142 #[cfg(feature = "openapi")]
169 pub fn with_openapi_docs(
170 self,
171 title: &str,
172 version: &str,
173 docs: &[typeway_core::HandlerDoc],
174 ) -> Self
175 where
176 A: typeway_openapi::ApiToSpec,
177 {
178 let mut spec = A::to_spec(title, version);
179 typeway_openapi::apply_handler_docs(&mut spec, docs);
180
181 let spec_json = std::sync::Arc::new(
182 serde_json::to_string_pretty(&spec).expect("OpenAPI spec serialization failed"),
183 );
184
185 let router = &self.router;
186
187 let spec_json_str =
188 serde_json::to_string(&spec).expect("OpenAPI spec serialization failed");
189
190 router.add_route(
191 http::Method::GET,
192 "/openapi.json".to_string(),
193 crate::openapi::exact_match(&["openapi.json"]),
194 crate::openapi::spec_handler(spec_json.clone()),
195 );
196
197 router.add_route(
198 http::Method::GET,
199 "/docs".to_string(),
200 crate::openapi::exact_match(&["docs"]),
201 crate::openapi::docs_handler(title, version, &spec_json_str),
202 );
203
204 self
205 }
206
207 pub fn with_static_files(self, prefix: &str, dir: impl Into<std::path::PathBuf>) -> Self {
221 let dir: std::path::PathBuf = dir.into();
222 let prefix_segments: Vec<String> = prefix
223 .split('/')
224 .filter(|s| !s.is_empty())
225 .map(|s| s.to_string())
226 .collect();
227 let prefix_len = prefix_segments.len();
228
229 let router = &self.router;
230
231 let dir = Arc::new(dir);
232 let prefix_segs = Arc::new(prefix_segments);
233
234 router.add_route(
236 http::Method::GET,
237 format!("{prefix}/{{*path}}"),
238 {
239 let prefix_segs = prefix_segs.clone();
240 Box::new(move |segments: &[&str]| {
241 segments.len() >= prefix_segs.len()
243 && segments[..prefix_segs.len()]
244 .iter()
245 .zip(prefix_segs.iter())
246 .all(|(a, b)| *a == b.as_str())
247 })
248 },
249 {
250 let dir = dir.clone();
251 std::sync::Arc::new(move |parts: http::request::Parts, _body: bytes::Bytes| {
252 let dir = dir.clone();
253 Box::pin(async move {
254 let path = parts.uri.path();
255 let file_path: String = path
257 .splitn(prefix_len + 2, '/')
258 .skip(prefix_len + 1)
259 .collect::<Vec<_>>()
260 .join("/");
261
262 if file_path.contains("..") {
264 let mut res = http::Response::new(crate::body::body_from_string(
265 "Forbidden".to_string(),
266 ));
267 *res.status_mut() = http::StatusCode::FORBIDDEN;
268 return res;
269 }
270
271 let full_path = if file_path.is_empty() {
272 dir.join("index.html")
274 } else {
275 let p = dir.join(&file_path);
276 if p.is_dir() {
278 p.join("index.html")
279 } else {
280 p
281 }
282 };
283
284 match tokio::fs::read(&full_path).await {
285 Ok(contents) => {
286 let mime = mime_from_path(&full_path);
287 let body =
288 crate::body::body_from_bytes(bytes::Bytes::from(contents));
289 let mut res = http::Response::new(body);
290 if let Ok(val) = http::HeaderValue::from_str(mime) {
291 res.headers_mut().insert(http::header::CONTENT_TYPE, val);
292 }
293 res
294 }
295 Err(_) => {
296 let mut res = http::Response::new(crate::body::body_from_string(
297 "Not Found".to_string(),
298 ));
299 *res.status_mut() = http::StatusCode::NOT_FOUND;
300 res
301 }
302 }
303 }) as crate::handler::ResponseFuture
304 })
305 },
306 );
307
308 self
309 }
310
311 pub fn with_spa_fallback(self, index_path: impl Into<std::path::PathBuf>) -> Self {
326 let index_path: std::path::PathBuf = index_path.into();
327
328 let html = match std::fs::read_to_string(&index_path) {
330 Ok(contents) => Arc::new(contents),
331 Err(e) => {
332 tracing::warn!(
333 "WARNING: SPA fallback file not found: {} ({}). \
334 Unmatched routes will show an error page.",
335 index_path.display(),
336 e
337 );
338 let error_page = format!(
339 "<!DOCTYPE html><html><body>\
340 <h1>Frontend Not Available</h1>\
341 <p>The SPA fallback file <code>{}</code> could not be loaded: {}</p>\
342 <p>If running locally, build the frontend first:</p>\
343 <pre>cd examples/realworld/frontend\nelm make src/Main.elm --output=public/elm.js</pre>\
344 </body></html>",
345 index_path.display(),
346 e
347 );
348 Arc::new(error_page)
349 }
350 };
351
352 self.set_fallback_raw(Arc::new(move |req| {
353 let html = html.clone();
354 let path = req.uri().path().to_string();
355 Box::pin(async move {
356 let last_segment = path.rsplit('/').next().unwrap_or("");
359 if last_segment.contains('.') {
360 let mut res =
361 http::Response::new(crate::body::body_from_string("Not Found".to_string()));
362 *res.status_mut() = http::StatusCode::NOT_FOUND;
363 return res;
364 }
365
366 let body = crate::body::body_from_string(html.to_string());
367 let mut res = http::Response::new(body);
368 res.headers_mut().insert(
369 http::header::CONTENT_TYPE,
370 http::HeaderValue::from_static("text/html; charset=utf-8"),
371 );
372 res
373 })
374 }));
375
376 self
377 }
378
379 pub(crate) fn set_fallback_raw(&self, fallback: crate::router::FallbackService) {
383 let router = &self.router;
384 router.set_fallback(fallback);
385 }
386
387 pub fn with_fallback<S>(self, service: S) -> Self
404 where
405 S: tower_service::Service<
406 http::Request<hyper::body::Incoming>,
407 Response = http::Response<BoxBody>,
408 Error = Infallible,
409 > + Clone
410 + Send
411 + Sync
412 + 'static,
413 S::Future: Send + 'static,
414 {
415 self.set_fallback_raw(Arc::new(
416 move |req: http::Request<hyper::body::Incoming>| {
417 let mut svc = service.clone();
418 Box::pin(async move {
419 tower_service::Service::call(&mut svc, req)
420 .await
421 .unwrap_or_else(|e| match e {})
422 })
423 },
424 ));
425 self
426 }
427
428 #[cfg(feature = "grpc")]
447 pub fn with_grpc(self, service_name: &str, package: &str) -> crate::grpc::GrpcServer<A>
448 where
449 A: typeway_grpc::CollectRpcs + typeway_grpc::GrpcReady,
450 {
451 crate::grpc::make_grpc_server(self.router, service_name, package)
452 }
453
454 pub fn layer<L>(self, layer: L) -> LayeredServer<L::Service>
473 where
474 L: tower_layer::Layer<RouterService>,
475 L::Service: tower_service::Service<
476 http::Request<hyper::body::Incoming>,
477 Response = http::Response<BoxBody>,
478 Error = Infallible,
479 > + Clone
480 + Send
481 + 'static,
482 <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
483 Send + 'static,
484 {
485 let router = self.router.clone();
486 let svc = RouterService::new(self.router);
487 let layered = layer.layer(svc);
488 LayeredServer {
489 service: layered,
490 router,
491 }
492 }
493
494 #[cfg(feature = "tls")]
505 pub async fn serve_tls(
506 self,
507 addr: SocketAddr,
508 tls: crate::tls::TlsConfig,
509 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
510 let listener = TcpListener::bind(addr).await?;
511 tracing::info!("Listening on https://{addr}");
512 let router = self.router.clone();
513 crate::tls::serve_tls_loop(listener, tls, move || {
514 hyper_util::service::TowerToHyperService::new(RouterService::new(router.clone()))
515 })
516 .await
517 }
518
519 pub fn into_service(self) -> RouterService {
521 RouterService::new(self.router)
522 }
523
524 pub fn into_router(self) -> Router {
526 Arc::try_unwrap(self.router).unwrap_or_else(|_| {
527 panic!("cannot unwrap router — it has been cloned");
528 })
529 }
530
531 pub async fn serve(
533 self,
534 addr: SocketAddr,
535 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
536 let listener = TcpListener::bind(addr).await?;
537 tracing::info!("Listening on http://{addr}");
538 self.serve_with_shutdown(listener, std::future::pending())
539 .await
540 }
541
542 pub async fn serve_with_shutdown(
559 self,
560 listener: TcpListener,
561 shutdown: impl Future<Output = ()> + Send,
562 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
563 tokio::pin!(shutdown);
564
565 loop {
566 tokio::select! {
567 result = listener.accept() => {
568 let (stream, _) = result?;
569 let io = TokioIo::new(stream);
570 let svc = RouterService::new(self.router.clone());
571 let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
572
573 tokio::task::spawn(async move {
574 if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
575 .serve_connection(io, hyper_svc)
576 .await
577 {
578 tracing::debug!("Connection closed: {e}");
579 }
580 });
581 }
582 () = &mut shutdown => {
583 tracing::info!("Shutting down gracefully...");
584 return Ok(());
585 }
586 }
587 }
588 }
589}
590
591pub struct LayeredServer<S> {
595 pub service: S,
597 pub(crate) router: Arc<Router>,
599}
600
601impl<S> LayeredServer<S> {
602 pub fn with_state<T: Clone + Send + Sync + 'static>(self, state: T) -> Self {
604 self.router.set_state_injector(Arc::new(move |ext| {
605 ext.insert(state.clone());
606 }));
607 self
608 }
609
610 pub fn max_body_size(self, max: usize) -> Self {
612 self.router.set_max_body_size(max);
613 self
614 }
615
616 pub fn nest(self, prefix: &str) -> Self {
618 self.router.set_prefix(prefix);
619 self
620 }
621
622 pub fn with_static_files(self, prefix: &str, dir: impl Into<std::path::PathBuf>) -> Self {
624 let dir: std::path::PathBuf = dir.into();
626 let prefix_segments: Vec<String> = prefix
627 .split('/')
628 .filter(|s| !s.is_empty())
629 .map(|s| s.to_string())
630 .collect();
631 let prefix_len = prefix_segments.len();
632 let dir = Arc::new(dir);
633 let prefix_segs = Arc::new(prefix_segments);
634
635 self.router.add_route(
636 http::Method::GET,
637 format!("{prefix}/{{*path}}"),
638 {
639 let prefix_segs = prefix_segs.clone();
640 Box::new(move |segments: &[&str]| {
641 segments.len() >= prefix_segs.len()
642 && segments[..prefix_segs.len()]
643 .iter()
644 .zip(prefix_segs.iter())
645 .all(|(a, b)| *a == b.as_str())
646 })
647 },
648 {
649 let dir = dir.clone();
650 std::sync::Arc::new(move |parts: http::request::Parts, _body: bytes::Bytes| {
651 let dir = dir.clone();
652 Box::pin(async move {
653 let path = parts.uri.path();
654 let file_path: String = path
655 .splitn(prefix_len + 2, '/')
656 .skip(prefix_len + 1)
657 .collect::<Vec<_>>()
658 .join("/");
659 if file_path.contains("..") {
660 let mut res = http::Response::new(crate::body::body_from_string(
661 "Forbidden".to_string(),
662 ));
663 *res.status_mut() = http::StatusCode::FORBIDDEN;
664 return res;
665 }
666 let full_path = if file_path.is_empty() {
667 dir.join("index.html")
668 } else {
669 let p = dir.join(&file_path);
670 if p.is_dir() {
671 p.join("index.html")
672 } else {
673 p
674 }
675 };
676 match tokio::fs::read(&full_path).await {
677 Ok(contents) => {
678 let mime = mime_from_path(&full_path);
679 let body =
680 crate::body::body_from_bytes(bytes::Bytes::from(contents));
681 let mut res = http::Response::new(body);
682 if let Ok(val) = http::HeaderValue::from_str(mime) {
683 res.headers_mut().insert(http::header::CONTENT_TYPE, val);
684 }
685 res
686 }
687 Err(_) => {
688 let mut res = http::Response::new(crate::body::body_from_string(
689 "Not Found".to_string(),
690 ));
691 *res.status_mut() = http::StatusCode::NOT_FOUND;
692 res
693 }
694 }
695 }) as crate::handler::ResponseFuture
696 })
697 },
698 );
699 self
700 }
701
702 pub fn with_spa_fallback(self, index_path: impl Into<std::path::PathBuf>) -> Self {
704 let index_path: std::path::PathBuf = index_path.into();
705 let html = match std::fs::read_to_string(&index_path) {
706 Ok(contents) => Arc::new(contents),
707 Err(e) => {
708 tracing::warn!(
709 "WARNING: SPA fallback file not found: {} ({})",
710 index_path.display(),
711 e
712 );
713 Arc::new(format!(
714 "<!DOCTYPE html><html><body>\
715 <h1>Frontend Not Available</h1>\
716 <p><code>{}</code>: {}</p></body></html>",
717 index_path.display(),
718 e
719 ))
720 }
721 };
722 self.router.set_fallback(Arc::new(move |req| {
723 let html = html.clone();
724 let path = req.uri().path().to_string();
725 Box::pin(async move {
726 let last_segment = path.rsplit('/').next().unwrap_or("");
727 if last_segment.contains('.') {
728 let mut res =
729 http::Response::new(crate::body::body_from_string("Not Found".to_string()));
730 *res.status_mut() = http::StatusCode::NOT_FOUND;
731 return res;
732 }
733 let body = crate::body::body_from_string(html.to_string());
734 let mut res = http::Response::new(body);
735 res.headers_mut().insert(
736 http::header::CONTENT_TYPE,
737 http::HeaderValue::from_static("text/html; charset=utf-8"),
738 );
739 res
740 })
741 }));
742 self
743 }
744}
745
746impl<S> LayeredServer<S>
747where
748 S: tower_service::Service<
749 http::Request<hyper::body::Incoming>,
750 Response = http::Response<BoxBody>,
751 Error = Infallible,
752 > + Clone
753 + Send
754 + 'static,
755 S::Future: Send + 'static,
756{
757 pub fn layer<L>(self, layer: L) -> LayeredServer<L::Service>
759 where
760 L: tower_layer::Layer<S>,
761 L::Service: tower_service::Service<
762 http::Request<hyper::body::Incoming>,
763 Response = http::Response<BoxBody>,
764 Error = Infallible,
765 > + Clone
766 + Send
767 + 'static,
768 <L::Service as tower_service::Service<http::Request<hyper::body::Incoming>>>::Future:
769 Send + 'static,
770 {
771 LayeredServer {
772 service: layer.layer(self.service),
773 router: self.router,
774 }
775 }
776
777 pub async fn serve(
779 self,
780 addr: SocketAddr,
781 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
782 let listener = TcpListener::bind(addr).await?;
783 tracing::info!("Listening on http://{addr}");
784 self.serve_with_shutdown(listener, std::future::pending())
785 .await
786 }
787
788 pub async fn serve_with_shutdown(
790 self,
791 listener: TcpListener,
792 shutdown: impl Future<Output = ()> + Send,
793 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
794 tokio::pin!(shutdown);
795
796 loop {
797 tokio::select! {
798 result = listener.accept() => {
799 let (stream, _) = result?;
800 let io = TokioIo::new(stream);
801 let svc = self.service.clone();
802 let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
803
804 tokio::task::spawn(async move {
805 if let Err(e) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
806 .serve_connection(io, hyper_svc)
807 .await
808 {
809 tracing::debug!("Connection closed: {e}");
810 }
811 });
812 }
813 () = &mut shutdown => {
814 tracing::info!("Shutting down gracefully...");
815 return Ok(());
816 }
817 }
818 }
819 }
820}
821
822pub async fn serve<A: ApiSpec, H: Serves<A>>(
830 addr: SocketAddr,
831 handlers: H,
832) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
833 Server::<A>::new(handlers).serve(addr).await
834}
835
/// Guesses a `Content-Type` value from a file's extension.
///
/// Matching is ASCII case-insensitive, so `LOGO.PNG` and `logo.png` map to
/// the same type. Paths without a UTF-8 extension, and unknown extensions,
/// fall back to `application/octet-stream`.
fn mime_from_path(path: &std::path::Path) -> &'static str {
    let Some(ext) = path.extension().and_then(|e| e.to_str()) else {
        return "application/octet-stream";
    };
    match ext.to_ascii_lowercase().as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "json" | "map" => "application/json",
        "webmanifest" => "application/manifest+json",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "txt" => "text/plain; charset=utf-8",
        "csv" => "text/csv; charset=utf-8",
        "xml" => "application/xml",
        "pdf" => "application/pdf",
        "wasm" => "application/wasm",
        "mp4" => "video/mp4",
        "webm" => "video/webm",
        "mp3" => "audio/mpeg",
        _ => "application/octet-stream",
    }
}