1use std::future::Future;
20#[cfg(not(feature = "compio"))]
21use std::path::PathBuf;
22#[cfg(not(feature = "compio"))]
23use std::pin::Pin;
24#[cfg(any(not(feature = "compio"), feature = "tls"))]
29use std::sync::Arc;
30use std::time::Duration;
31
32use tako_rs_core::router::Router;
33#[cfg(not(feature = "compio"))]
34use tokio::net::TcpListener;
35
36use crate::ServerConfig;
37
/// Handle to a server task started by one of the `spawn_*` methods.
///
/// It holds two cancellation tokens: `shutdown`, which the owner cancels to ask
/// the server to stop, and `done`, which the spawned task cancels once the serve
/// future has finished.
pub struct ServerHandle {
    // Cancelled by `trigger`/`shutdown`; the serve loop's shutdown future waits on it.
    shutdown: tokio_util::sync::CancellationToken,
    // Cancelled by the spawned task when the serve future is finished.
    done: tokio_util::sync::CancellationToken,
    // Copy of the server config's `drain_timeout` at the time the handle was created.
    drain_timeout: Duration,
}
52
53impl ServerHandle {
54 pub fn trigger(&self) {
56 self.shutdown.cancel();
57 }
58
59 pub async fn join(&self) {
65 self.done.cancelled().await;
66 }
67
68 pub async fn shutdown(self, _timeout: Duration) {
75 self.shutdown.cancel();
76 self.done.cancelled().await;
77 }
78
79 #[inline]
81 pub fn drain_timeout(&self) -> Duration {
82 self.drain_timeout
83 }
84}
85
86impl std::fmt::Debug for ServerHandle {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("ServerHandle")
89 .field("drain_timeout", &self.drain_timeout)
90 .finish_non_exhaustive()
91 }
92}
93
/// Runs `a` and `b` concurrently and returns as soon as either one completes.
///
/// `a` is polled before `b` on every wake-up, so if both are ready at the same
/// time only `a` is driven to completion and `b` is never polled. Whichever
/// future did not finish is dropped when this function returns.
pub async fn either<A, B>(a: A, b: B)
where
    A: Future<Output = ()>,
    B: Future<Output = ()>,
{
    // A two-way race needs nothing beyond `std`. This replaces
    // `futures_util::future::select` plus a `match` whose arms discarded
    // everything.
    let mut a = std::pin::pin!(a);
    let mut b = std::pin::pin!(b);
    std::future::poll_fn(move |cx| {
        // `||` short-circuits: `b` is only polled while `a` is still pending,
        // which matches `select`'s left-first polling order.
        if a.as_mut().poll(cx).is_ready() || b.as_mut().poll(cx).is_ready() {
            std::task::Poll::Ready(())
        } else {
            std::task::Poll::Pending
        }
    })
    .await;
}
107
/// Client-certificate (mTLS) policy applied by `build_rustls_server_config`.
///
/// Both variants carry the root store used to build a
/// `rustls::server::WebPkiClientVerifier`.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub enum ClientAuth {
    /// Client certificates are verified if presented, but the verifier is built
    /// with `allow_unauthenticated()`, so clients without one are still accepted.
    Optional(Arc<rustls::RootCertStore>),
    /// Client certificates are required (no `allow_unauthenticated()`).
    Required(Arc<rustls::RootCertStore>),
}
122
#[cfg(feature = "tls")]
impl std::fmt::Debug for ClientAuth {
    // Only the variant name is shown; the root store is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Optional(_) => "Optional",
            Self::Required(_) => "Required",
        };
        f.debug_tuple(variant).field(&"<root_store>").finish()
    }
}
132
/// Where a TLS listener gets its certificate, plus an optional client-auth policy.
///
/// Without the `tls` feature only `PemPaths` exists, and it has no
/// client-auth field.
#[derive(Clone)]
pub enum TlsCert {
    /// Certificate chain and private key read from files at these paths
    /// (PEM, per `TlsCert::pem_paths`).
    PemPaths {
        /// Path to the certificate chain file.
        cert_path: String,
        /// Path to the private key file.
        key_path: String,
        /// Optional mTLS policy; `None` means no client authentication.
        #[cfg(feature = "tls")]
        client_auth: Option<ClientAuth>,
    },
    /// Certificate chain and private key supplied in DER form, shared behind `Arc`.
    #[cfg(feature = "tls")]
    Der {
        /// DER-encoded certificate chain.
        certs: Arc<Vec<rustls::pki_types::CertificateDer<'static>>>,
        /// DER-encoded private key.
        key: Arc<rustls::pki_types::PrivateKeyDer<'static>>,
        /// Optional mTLS policy; `None` means no client authentication.
        client_auth: Option<ClientAuth>,
    },
    /// A rustls certificate resolver (for example a `ReloadableResolver`).
    #[cfg(feature = "tls")]
    Resolver {
        /// Resolver installed via `with_cert_resolver`.
        resolver: Arc<dyn rustls::server::ResolvesServerCert>,
        /// Optional mTLS policy; `None` means no client authentication.
        client_auth: Option<ClientAuth>,
    },
}
175
176impl std::fmt::Debug for TlsCert {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 match self {
179 TlsCert::PemPaths {
180 cert_path,
181 key_path,
182 ..
183 } => f
184 .debug_struct("PemPaths")
185 .field("cert_path", cert_path)
186 .field("key_path", key_path)
187 .finish_non_exhaustive(),
188 #[cfg(feature = "tls")]
189 TlsCert::Der { client_auth, .. } => f
190 .debug_struct("Der")
191 .field("client_auth", client_auth)
192 .finish_non_exhaustive(),
193 #[cfg(feature = "tls")]
194 TlsCert::Resolver { client_auth, .. } => f
195 .debug_struct("Resolver")
196 .field("client_auth", client_auth)
197 .finish_non_exhaustive(),
198 }
199 }
200}
201
202impl TlsCert {
203 pub fn pem_paths(cert: impl Into<String>, key: impl Into<String>) -> Self {
205 Self::PemPaths {
206 cert_path: cert.into(),
207 key_path: key.into(),
208 #[cfg(feature = "tls")]
209 client_auth: None,
210 }
211 }
212
213 #[cfg(feature = "tls")]
215 pub fn pem_paths_with_client_auth(
216 cert: impl Into<String>,
217 key: impl Into<String>,
218 client_auth: ClientAuth,
219 ) -> Self {
220 Self::PemPaths {
221 cert_path: cert.into(),
222 key_path: key.into(),
223 client_auth: Some(client_auth),
224 }
225 }
226
227 #[cfg(feature = "tls")]
229 pub fn der(
230 certs: Vec<rustls::pki_types::CertificateDer<'static>>,
231 key: rustls::pki_types::PrivateKeyDer<'static>,
232 ) -> Self {
233 Self::Der {
234 certs: Arc::new(certs),
235 key: Arc::new(key),
236 client_auth: None,
237 }
238 }
239
240 #[cfg(feature = "tls")]
243 pub fn resolver(resolver: Arc<dyn rustls::server::ResolvesServerCert>) -> Self {
244 Self::Resolver {
245 resolver,
246 client_auth: None,
247 }
248 }
249
250 #[cfg(feature = "tls")]
256 pub fn with_client_auth(mut self, auth: ClientAuth) -> Self {
257 match &mut self {
258 TlsCert::PemPaths { client_auth, .. }
259 | TlsCert::Der { client_auth, .. }
260 | TlsCert::Resolver { client_auth, .. } => *client_auth = Some(auth),
261 }
262 self
263 }
264}
265
/// Certificate resolver whose certified key can be replaced at runtime.
///
/// `reload`/`reload_from_pem` swap the stored key via `ArcSwap::store`. Every
/// `resolve` call after that hands out the new key.
#[cfg(feature = "tls")]
pub struct ReloadableResolver {
    // Certificate chain + signing key returned (as a fresh `Arc`) by `resolve`.
    current: arc_swap::ArcSwap<rustls::sign::CertifiedKey>,
}
294
#[cfg(feature = "tls")]
impl std::fmt::Debug for ReloadableResolver {
    // The current certified key (which includes the signing key) is deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ReloadableResolver");
        dbg.finish_non_exhaustive()
    }
}
301
#[cfg(feature = "tls")]
impl ReloadableResolver {
    /// Creates a resolver seeded with the certificate and key at the given paths.
    ///
    /// # Errors
    /// Fails if either file cannot be loaded or the installed rustls
    /// `CryptoProvider` cannot turn the key into a signing key.
    pub fn from_pem(cert_path: &str, key_path: &str) -> anyhow::Result<Self> {
        build_certified_key(cert_path, key_path).map(|ck| Self {
            current: arc_swap::ArcSwap::from_pointee(ck),
        })
    }

    /// Re-reads the certificate and key from disk and starts serving them.
    ///
    /// # Errors
    /// Same as [`ReloadableResolver::from_pem`]. On error the previously
    /// loaded key stays in place.
    pub fn reload_from_pem(&self, cert_path: &str, key_path: &str) -> anyhow::Result<()> {
        let fresh = build_certified_key(cert_path, key_path)?;
        self.reload(fresh);
        Ok(())
    }

    /// Replaces the served certified key with `ck`.
    pub fn reload(&self, ck: rustls::sign::CertifiedKey) {
        let next = Arc::new(ck);
        self.current.store(next);
    }
}
327
#[cfg(feature = "tls")]
impl rustls::server::ResolvesServerCert for ReloadableResolver {
    // The ClientHello (SNI, ALPN, ...) is ignored: every handshake gets the
    // key that is current at that moment.
    fn resolve(
        &self,
        _client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let current = self.current.load_full();
        Some(current)
    }
}
337
/// Loads a certificate chain and private key and pairs them into a rustls
/// `CertifiedKey`.
///
/// If no process-wide rustls `CryptoProvider` is installed yet, this installs
/// aws-lc-rs as the default. If a provider was installed by someone other than
/// this function, a one-time warning is logged and that provider loads the key.
///
/// # Errors
/// Fails if the cert/key files cannot be loaded, if no provider is installed,
/// or if the provider rejects the private key.
#[cfg(feature = "tls")]
fn build_certified_key(
    cert_path: &str,
    key_path: &str,
) -> anyhow::Result<rustls::sign::CertifiedKey> {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Records that *this* function installed aws-lc-rs as the process default.
    // Without it, every call after the first one (e.g. `reload_from_pem` after
    // `from_pem`) sees a provider already present and wrongly warns that a
    // foreign provider won, even though it is the aws-lc-rs we installed.
    static TAKO_INSTALLED_PROVIDER: AtomicBool = AtomicBool::new(false);

    let certs = tako_rs_core::tls::load_certs(cert_path)?;
    let key = tako_rs_core::tls::load_key(key_path)?;

    if rustls::crypto::CryptoProvider::get_default().is_none()
        && rustls::crypto::aws_lc_rs::default_provider()
            .install_default()
            .is_ok()
    {
        TAKO_INSTALLED_PROVIDER.store(true, Ordering::SeqCst);
    }
    if !TAKO_INSTALLED_PROVIDER.load(Ordering::SeqCst) {
        // Warn only once per process, however many certs are loaded.
        static WARNED: std::sync::Once = std::sync::Once::new();
        WARNED.call_once(|| {
            tracing::warn!(
                "tako-server: a rustls CryptoProvider was already installed before \
                 `build_certified_key` ran — Tako will use that provider for key \
                 loading instead of installing aws-lc-rs. If signing behavior is \
                 not what you expect (e.g. h3 installed `ring` first), pin the \
                 provider at process startup with `rustls::crypto::aws_lc_rs::\
                 default_provider().install_default()` BEFORE constructing the \
                 server."
            );
        });
    }
    let provider = rustls::crypto::CryptoProvider::get_default().ok_or_else(|| {
        anyhow::anyhow!(
            "no rustls CryptoProvider installed — enable rustls's `aws_lc_rs` or `ring` feature"
        )
    })?;
    let signer = provider
        .key_provider
        .load_private_key(key)
        .map_err(|e| anyhow::anyhow!("failed to load signing key from '{key_path}': {e}"))?;
    Ok(rustls::sign::CertifiedKey::new(certs, signer))
}
393
/// Builds a shared rustls `ServerConfig` for `cert`, advertising `alpn`.
///
/// The client-auth policy inside `cert` selects the verifier: none, optional
/// (`allow_unauthenticated`), or required. When `alpn` contains `h3`,
/// `max_early_data_size` is forced to 0.
///
/// # Errors
/// Fails if the client verifier cannot be built, if the PEM files cannot be
/// loaded, or if rustls rejects the certificate/key pair.
#[cfg(feature = "tls")]
pub fn build_rustls_server_config(
    cert: &TlsCert,
    alpn: Vec<Vec<u8>>,
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
    use rustls::ServerConfig as RustlsServerConfig;

    // Every variant carries the same optional mTLS policy.
    let (TlsCert::PemPaths { client_auth, .. }
    | TlsCert::Der { client_auth, .. }
    | TlsCert::Resolver { client_auth, .. }) = cert;

    let builder = RustlsServerConfig::builder();
    let builder = match client_auth.clone() {
        None => builder.with_no_client_auth(),
        Some(policy) => {
            let (roots, allow_anonymous) = match policy {
                ClientAuth::Optional(roots) => (roots, true),
                ClientAuth::Required(roots) => (roots, false),
            };
            let mut verifier_builder = rustls::server::WebPkiClientVerifier::builder(roots);
            if allow_anonymous {
                verifier_builder = verifier_builder.allow_unauthenticated();
            }
            let verifier = verifier_builder
                .build()
                .map_err(|e| anyhow::anyhow!("WebPkiClientVerifier build failed: {e}"))?;
            builder.with_client_cert_verifier(verifier)
        }
    };

    let mut config = match cert {
        TlsCert::PemPaths {
            cert_path,
            key_path,
            ..
        } => {
            let chain = tako_rs_core::tls::load_certs(cert_path)?;
            let private_key = tako_rs_core::tls::load_key(key_path)?;
            builder
                .with_single_cert(chain, private_key)
                .map_err(|e| anyhow::anyhow!("rustls config build failed: {e}"))?
        }
        TlsCert::Der { certs, key, .. } => builder
            .with_single_cert(certs.as_ref().clone(), key.as_ref().clone_key())
            .map_err(|e| anyhow::anyhow!("rustls config build failed: {e}"))?,
        TlsCert::Resolver { resolver, .. } => builder.with_cert_resolver(Arc::clone(resolver)),
    };

    let advertises_h3 = alpn.iter().any(|proto| proto.as_slice() == b"h3");
    config.alpn_protocols = alpn;
    if advertises_h3 {
        // TLS-level early data is switched off whenever h3 is offered.
        // NOTE(review): presumably because QUIC handles 0-RTT itself — confirm.
        config.max_early_data_size = 0;
    }
    Ok(Arc::new(config))
}
469
/// Builder for [`Server`]; obtain one with `Server::builder()`.
#[cfg(not(feature = "compio"))]
#[derive(Debug, Default, Clone)]
pub struct ServerBuilder {
    // Configuration copied into the built `Server`.
    config: ServerConfig,
    // Certificate source required by `Server::spawn_tls` / `Server::spawn_h3`.
    tls: Option<TlsCert>,
}
477
478#[cfg(not(feature = "compio"))]
479impl ServerBuilder {
480 #[must_use]
482 pub fn config(mut self, config: ServerConfig) -> Self {
483 self.config = config;
484 self
485 }
486
487 #[must_use]
489 pub fn tls(mut self, cert: TlsCert) -> Self {
490 self.tls = Some(cert);
491 self
492 }
493
494 pub fn build(self) -> Server {
496 Server {
497 config: self.config,
498 tls: self.tls,
499 }
500 }
501}
502
/// Tokio-based server front-end: holds shared configuration and an optional
/// certificate, and launches listeners in background tasks via `spawn_*`.
#[cfg(not(feature = "compio"))]
#[derive(Debug, Clone)]
pub struct Server {
    // Cloned into every spawned serve task.
    config: ServerConfig,
    // Only read by `spawn_tls`/`spawn_h3`, which are compiled out without
    // the `tls`/`http3` features, hence the conditional `dead_code` allowance.
    #[cfg_attr(not(any(feature = "tls", feature = "http3")), allow(dead_code))]
    tls: Option<TlsCert>,
}
514
515#[cfg(not(feature = "compio"))]
516impl Server {
517 #[must_use]
519 pub fn builder() -> ServerBuilder {
520 ServerBuilder::default()
521 }
522
523 #[inline]
525 pub fn config(&self) -> &ServerConfig {
526 &self.config
527 }
528
529 pub fn spawn_http(&self, listener: TcpListener, router: Router) -> ServerHandle {
533 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
534 let config = self.config.clone();
535 spawn_done(handle.done.clone(), async move {
536 crate::server::serve_with_shutdown_and_config(listener, router, shutdown_fut, config).await;
537 });
538 handle
539 }
540
541 #[cfg(feature = "http2")]
543 pub fn spawn_h2c(&self, listener: TcpListener, router: Router) -> ServerHandle {
544 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
545 let config = self.config.clone();
546 spawn_done(handle.done.clone(), async move {
547 crate::server_h2c::serve_h2c_with_shutdown_and_config(listener, router, shutdown_fut, config)
548 .await;
549 });
550 handle
551 }
552
553 #[cfg(feature = "tls")]
559 pub fn spawn_tls(&self, listener: TcpListener, router: Router) -> ServerHandle {
560 let tls = self
561 .tls
562 .clone()
563 .expect("Server::spawn_tls requires a TlsCert (use builder().tls(...))");
564 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
565 let config = self.config.clone();
566 let alpn = tls_alpn_for_tcp();
567 spawn_done(handle.done.clone(), async move {
568 if let TlsCert::PemPaths {
571 cert_path,
572 key_path,
573 client_auth: None,
574 } = &tls
575 {
576 crate::server_tls::serve_tls_with_shutdown_and_config(
577 listener,
578 router,
579 Some(cert_path.as_str()),
580 Some(key_path.as_str()),
581 shutdown_fut,
582 config,
583 )
584 .await;
585 return;
586 }
587 let rustls_cfg = match build_rustls_server_config(&tls, alpn) {
588 Ok(c) => c,
589 Err(e) => {
590 tracing::error!("Server::spawn_tls: failed to build rustls config: {e}");
591 return;
592 }
593 };
594 crate::server_tls::serve_tls_with_rustls_config_and_shutdown(
595 listener,
596 router,
597 rustls_cfg,
598 shutdown_fut,
599 config,
600 )
601 .await;
602 });
603 handle
604 }
605
606 #[cfg(feature = "http3")]
609 pub fn spawn_h3(&self, addr: impl Into<String>, router: Router) -> ServerHandle {
610 let tls = self
611 .tls
612 .clone()
613 .expect("Server::spawn_h3 requires a TlsCert (use builder().tls(...))");
614 let addr = addr.into();
615 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
616 let config = self.config.clone();
617 spawn_done(handle.done.clone(), async move {
618 if let TlsCert::PemPaths {
619 cert_path,
620 key_path,
621 client_auth: None,
622 } = &tls
623 {
624 crate::server_h3::serve_h3_with_shutdown_and_config(
625 router,
626 &addr,
627 Some(cert_path.as_str()),
628 Some(key_path.as_str()),
629 shutdown_fut,
630 config,
631 )
632 .await;
633 return;
634 }
635 let rustls_cfg = match build_rustls_server_config(&tls, vec![b"h3".to_vec()]) {
636 Ok(c) => c,
637 Err(e) => {
638 tracing::error!("Server::spawn_h3: failed to build rustls config: {e}");
639 return;
640 }
641 };
642 crate::server_h3::serve_h3_with_rustls_config_and_shutdown(
643 router,
644 &addr,
645 rustls_cfg,
646 shutdown_fut,
647 config,
648 )
649 .await;
650 });
651 handle
652 }
653
654 #[cfg(unix)]
656 pub fn spawn_unix_http(&self, path: impl Into<PathBuf>, router: Router) -> ServerHandle {
657 let path = path.into();
658 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
659 let config = self.config.clone();
660 spawn_done(handle.done.clone(), async move {
661 crate::server_unix::serve_unix_http_with_shutdown_and_config(
662 path,
663 router,
664 shutdown_fut,
665 config,
666 )
667 .await;
668 });
669 handle
670 }
671
672 #[cfg(all(target_os = "linux", feature = "vsock"))]
675 pub fn spawn_vsock_http(&self, cid: u32, port: u32, router: Router) -> ServerHandle {
676 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
677 let config = self.config.clone();
678 spawn_done(handle.done.clone(), async move {
679 crate::server_vsock::serve_vsock_http_with_shutdown_and_config(
680 cid,
681 port,
682 router,
683 shutdown_fut,
684 config,
685 )
686 .await;
687 });
688 handle
689 }
690
691 pub fn spawn_proxy_protocol(&self, listener: TcpListener, router: Router) -> ServerHandle {
693 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
694 let config = self.config.clone();
695 spawn_done(handle.done.clone(), async move {
696 crate::proxy_protocol::serve_http_with_proxy_protocol_shutdown_and_config(
697 listener,
698 router,
699 shutdown_fut,
700 config,
701 )
702 .await;
703 });
704 handle
705 }
706
707 pub fn spawn_tcp_raw<F>(&self, addr: impl Into<String>, handler: F) -> ServerHandle
711 where
712 F: Fn(
713 tokio::net::TcpStream,
714 std::net::SocketAddr,
715 ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
716 + Send
717 + Sync
718 + 'static,
719 {
720 let addr = addr.into();
721 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
722 spawn_done(handle.done.clone(), async move {
723 if let Err(e) = crate::server_tcp::serve_tcp_with_shutdown(&addr, handler, shutdown_fut).await
724 {
725 tracing::error!("raw TCP server error: {e}");
726 }
727 });
728 handle
729 }
730
731 pub fn spawn_udp_raw<F>(&self, addr: impl Into<String>, handler: F) -> ServerHandle
733 where
734 F: Fn(
735 Vec<u8>,
736 std::net::SocketAddr,
737 Arc<tokio::net::UdpSocket>,
738 ) -> Pin<Box<dyn Future<Output = ()> + Send>>
739 + Send
740 + Sync
741 + 'static,
742 {
743 let addr = addr.into();
744 let (handle, shutdown_fut) = make_handle(self.config.drain_timeout);
745 spawn_done(handle.done.clone(), async move {
746 if let Err(e) = crate::server_udp::serve_udp_with_shutdown(&addr, handler, shutdown_fut).await
747 {
748 tracing::error!("raw UDP server error: {e}");
749 }
750 });
751 handle
752 }
753}
754
/// Builder for [`CompioServer`]; obtain one with `CompioServer::builder()`.
#[cfg(feature = "compio")]
#[derive(Debug, Default, Clone)]
pub struct CompioServerBuilder {
    // Configuration copied into the built `CompioServer`.
    config: ServerConfig,
    // Certificate source required by `CompioServer::spawn_tls`.
    #[cfg(feature = "compio-tls")]
    tls: Option<TlsCert>,
}
765
#[cfg(feature = "compio")]
impl CompioServerBuilder {
    /// Replaces the server configuration.
    #[must_use]
    pub fn config(self, config: ServerConfig) -> Self {
        Self { config, ..self }
    }

    /// Sets the certificate used by `CompioServer::spawn_tls`.
    #[cfg(feature = "compio-tls")]
    #[must_use]
    pub fn tls(self, cert: TlsCert) -> Self {
        Self {
            tls: Some(cert),
            ..self
        }
    }

    /// Finishes the builder.
    pub fn build(self) -> CompioServer {
        CompioServer {
            #[cfg(feature = "compio-tls")]
            tls: self.tls,
            config: self.config,
        }
    }
}
792
/// compio-based server front-end: holds shared configuration and (with
/// `compio-tls`) an optional certificate, and launches listeners via `spawn_*`.
#[cfg(feature = "compio")]
#[derive(Debug, Clone)]
pub struct CompioServer {
    // Cloned into every spawned serve task.
    config: ServerConfig,
    // Only present with `compio-tls`; read by `spawn_tls`.
    #[cfg(feature = "compio-tls")]
    tls: Option<TlsCert>,
}
807
#[cfg(feature = "compio")]
impl CompioServer {
    /// Starts a [`CompioServerBuilder`] with default configuration.
    #[must_use]
    pub fn builder() -> CompioServerBuilder {
        CompioServerBuilder::default()
    }

    /// Configuration shared by every listener this server spawns.
    #[inline]
    pub fn config(&self) -> &ServerConfig {
        &self.config
    }

    /// Serves HTTP on `listener` in a detached compio task.
    pub fn spawn_http(&self, listener: compio::net::TcpListener, router: Router) -> ServerHandle {
        let cfg = self.config.clone();
        let (handle, on_shutdown) = make_handle(self.config.drain_timeout);
        spawn_done_compio(handle.done.clone(), async move {
            crate::server_compio::serve_with_shutdown_and_config(listener, router, on_shutdown, cfg)
                .await;
        });
        handle
    }

    /// Serves TLS on `listener` using the builder's certificate.
    ///
    /// A `PemPaths` certificate without client auth goes to the path-based TLS
    /// server. Any other source is first turned into a rustls config
    /// advertising `tls_alpn_for_tcp()`; if that fails, the error is logged and
    /// the task ends.
    ///
    /// # Panics
    /// Panics if no certificate was set with `CompioServerBuilder::tls`.
    #[cfg(feature = "compio-tls")]
    pub fn spawn_tls(&self, listener: compio::net::TcpListener, router: Router) -> ServerHandle {
        let tls = self
            .tls
            .clone()
            .expect("CompioServer::spawn_tls requires a TlsCert (use builder().tls(...))");
        let cfg = self.config.clone();
        let alpn = tls_alpn_for_tcp();
        let (handle, on_shutdown) = make_handle(self.config.drain_timeout);
        spawn_done_compio(handle.done.clone(), async move {
            match &tls {
                TlsCert::PemPaths {
                    cert_path,
                    key_path,
                    client_auth: None,
                } => {
                    crate::server_tls_compio::serve_tls_with_shutdown_and_config(
                        listener,
                        router,
                        Some(cert_path.as_str()),
                        Some(key_path.as_str()),
                        on_shutdown,
                        cfg,
                    )
                    .await;
                }
                _ => match build_rustls_server_config(&tls, alpn) {
                    Ok(rustls_cfg) => {
                        crate::server_tls_compio::serve_tls_with_rustls_config_and_shutdown(
                            listener,
                            router,
                            rustls_cfg,
                            on_shutdown,
                            cfg,
                        )
                        .await;
                    }
                    Err(e) => {
                        tracing::error!(
                            "CompioServer::spawn_tls: failed to build rustls config: {e}"
                        );
                    }
                },
            }
        });
        handle
    }
}
880
/// ALPN protocols advertised by TCP-based TLS listeners, in preference order.
///
/// `h2` comes first, but only when the `http2` feature is compiled in;
/// `http/1.1` is always offered.
#[cfg(feature = "tls")]
#[inline]
fn tls_alpn_for_tcp() -> Vec<Vec<u8>> {
    let mut protocols = Vec::new();
    if cfg!(feature = "http2") {
        protocols.push(b"h2".to_vec());
    }
    protocols.push(b"http/1.1".to_vec());
    protocols
}
895
896fn make_handle(
897 drain_timeout: Duration,
898) -> (ServerHandle, impl Future<Output = ()> + Send + 'static) {
899 let shutdown = tokio_util::sync::CancellationToken::new();
900 let done = tokio_util::sync::CancellationToken::new();
901 let shutdown_for_task = shutdown.clone();
902 let fut = async move {
907 shutdown_for_task.cancelled().await;
908 };
909 (
910 ServerHandle {
911 shutdown,
912 done,
913 drain_timeout,
914 },
915 fut,
916 )
917}
918
919#[cfg(not(feature = "compio"))]
920fn spawn_done<F>(done: tokio_util::sync::CancellationToken, fut: F)
921where
922 F: Future<Output = ()> + Send + 'static,
923{
924 tokio::spawn(async move {
925 fut.await;
926 done.cancel();
927 });
928}
929
/// Spawns `fut` as a detached compio task and cancels `done` when that task ends.
///
/// Like the Tokio variant, `done` is cancelled through a drop guard, so it
/// fires on normal completion and also when the task is unwound or dropped early.
#[cfg(feature = "compio")]
fn spawn_done_compio<F>(done: tokio_util::sync::CancellationToken, fut: F)
where
    F: Future<Output = ()> + 'static,
{
    // Cancelling only after `fut.await` returns would leave `done` untouched if
    // `fut` panics or the task is dropped, and `ServerHandle::join` would hang.
    // The guard lives in the task's captured state, so any teardown cancels it.
    let cancel_on_exit = done.drop_guard();
    compio::runtime::spawn(async move {
        let _cancel_on_exit = cancel_on_exit;
        fut.await;
    })
    .detach();
}
940}