1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use std::convert::Infallible;
21use std::io;
22use std::net::SocketAddr;
23use std::str::FromStr;
24use std::time::Duration;
25
26use hyper::server::conn::http1;
27use hyper::service::service_fn;
28use socket2::Domain;
29use socket2::Protocol;
30use socket2::Socket;
31use socket2::Type;
32use tako_rs_core::body::TakoBody;
33use tako_rs_core::conn_info::ConnInfo;
34use tako_rs_core::router::Router;
35use tokio::net::TcpListener;
36use tokio::runtime::Builder;
37use tokio::task::LocalSet;
38
/// Tuning knobs for the thread-per-core server started by [`serve_per_thread`]
/// or [`spawn_per_thread`].
#[derive(Debug, Clone)]
pub struct PerThreadConfig {
  /// How many worker threads to start; each one binds its own listener
  /// (with `SO_REUSEPORT` on Unix) on the same address.
  pub workers: usize,
  /// Pin worker `i` to CPU core `i`. Only honoured when the crate is built
  /// with the `affinity` feature.
  pub pin_to_core: bool,
  /// Backlog handed to `listen` for every worker's socket.
  pub backlog: i32,
  /// How long a worker keeps waiting for open connections after shutdown has
  /// been requested before giving up on them.
  pub drain_timeout: Duration,
}
52
53impl Default for PerThreadConfig {
54 fn default() -> Self {
55 Self {
56 workers: num_cpus(),
57 pin_to_core: cfg!(feature = "affinity"),
58 backlog: 1024,
59 drain_timeout: Duration::from_secs(30),
60 }
61 }
62}
63
64#[derive(Default)]
70struct BindStatus {
71 succeeded: std::sync::atomic::AtomicUsize,
73 failed: std::sync::atomic::AtomicUsize,
75 first_err: std::sync::Mutex<Option<io::Error>>,
80 notify: tokio::sync::Notify,
82}
83
84#[derive(Clone, Default)]
100pub struct PerThreadShutdown {
101 inner: tokio_util::sync::CancellationToken,
102 bind_status: std::sync::Arc<BindStatus>,
103}
104
105impl PerThreadShutdown {
106 #[must_use]
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn trigger(&self) {
115 self.inner.cancel();
116 }
117
118 pub async fn notified(&self) {
120 self.inner.cancelled().await;
121 }
122
123 pub(crate) fn report_bind_success(&self) {
125 self
126 .bind_status
127 .succeeded
128 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
129 self.bind_status.notify.notify_waiters();
130 }
131
132 pub(crate) fn report_bind_failure(&self, err: io::Error) {
136 {
137 let mut guard = self.bind_status.first_err.lock().unwrap();
138 if guard.is_none() {
139 *guard = Some(err);
140 }
141 }
142 self
143 .bind_status
144 .failed
145 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
146 self.bind_status.notify.notify_waiters();
147 }
148
149 pub async fn wait_for_bind_outcome(&self, total: usize) -> io::Result<()> {
154 use std::sync::atomic::Ordering;
155
156 loop {
157 let notified = self.bind_status.notify.notified();
160 tokio::pin!(notified);
161 notified.as_mut().enable();
162
163 let succ = self.bind_status.succeeded.load(Ordering::SeqCst);
164 let fail = self.bind_status.failed.load(Ordering::SeqCst);
165
166 if succ > 0 {
167 return Ok(());
168 }
169 if succ + fail >= total {
170 let err = self
171 .bind_status
172 .first_err
173 .lock()
174 .unwrap()
175 .take()
176 .unwrap_or_else(|| {
177 io::Error::other(format!("all {total} per-thread workers failed to bind"))
178 });
179 return Err(err);
180 }
181
182 notified.await;
183 }
184 }
185}
186
/// Number of CPUs this process may run on, or 1 if the platform cannot say.
fn num_cpus() -> usize {
  match std::thread::available_parallelism() {
    Ok(count) => count.get(),
    Err(_) => 1,
  }
}
190
/// Starting delay before the compio worker retries a failed `accept`; the
/// worker doubles it after each consecutive failure, capped at one second.
#[cfg(feature = "compio")]
fn compio_accept_backoff() -> Duration {
  const INITIAL: Duration = Duration::from_millis(5);
  INITIAL
}
195
/// Emits, at most once per process, a warning when the platform's
/// `SO_REUSEPORT` behaviour does not suit multi-worker mode.
///
/// Nothing is logged on Linux.
fn warn_reuseport_platform_once() {
  static REUSEPORT_WARNING: std::sync::Once = std::sync::Once::new();
  REUSEPORT_WARNING.call_once(|| {
    // Linux needs no warning; only the other platforms get a message.
    #[cfg(all(unix, not(target_os = "linux")))]
    tracing::warn!(
      "tako-server-pt: SO_REUSEPORT is being used on a non-Linux Unix \
       platform. The kernel typically sends incoming connections only to \
       the most recent binder, so multi-worker thread-per-core mode will \
       not load-balance correctly. Use a single worker or run on Linux."
    );
    #[cfg(windows)]
    tracing::warn!(
      "tako-server-pt: SO_REUSEPORT does not exist on Windows. Only the \
       first worker will accept connections; subsequent worker binds will \
       fail with EADDRINUSE. Use a single worker on Windows."
    );
  });
}
225
226fn bind_reuseport_std(addr: SocketAddr, backlog: i32) -> io::Result<std::net::TcpListener> {
227 warn_reuseport_platform_once();
228 let domain = if addr.is_ipv4() {
229 Domain::IPV4
230 } else {
231 Domain::IPV6
232 };
233 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
234 socket.set_reuse_address(true)?;
235 #[cfg(unix)]
240 socket.set_reuse_port(true)?;
241 socket.set_nonblocking(true)?;
242 socket.bind(&addr.into())?;
243 socket.listen(backlog)?;
244 Ok(socket.into())
245}
246
247fn bind_reuseport(addr: SocketAddr, backlog: i32) -> io::Result<TcpListener> {
248 TcpListener::from_std(bind_reuseport_std(addr, backlog)?)
249}
250
/// Binds a shareable listener (see `bind_reuseport_std`) and converts it into
/// a compio listener for the compio worker.
#[cfg(feature = "compio")]
fn bind_reuseport_compio(addr: SocketAddr, backlog: i32) -> io::Result<compio::net::TcpListener> {
  let std_listener = bind_reuseport_std(addr, backlog)?;
  compio::net::TcpListener::from_std(std_listener)
}
255
256pub fn serve_per_thread(addr: &str, router: Router, cfg: PerThreadConfig) -> io::Result<()> {
266 let workers = cfg.workers;
267 let (handle, shutdown) = spawn_per_thread(addr, router, cfg)?;
268 let rt = tokio::runtime::Builder::new_current_thread()
274 .enable_all()
275 .build()
276 .map_err(|e| io::Error::other(format!("ctrl-c runtime: {e}")))?;
277 let result: io::Result<()> = rt.block_on(async {
283 shutdown.wait_for_bind_outcome(workers).await?;
284 let _ = tokio::signal::ctrl_c().await;
285 Ok(())
286 });
287 shutdown.trigger();
288 for h in handle {
289 let _ = h.join();
290 }
291 result
292}
293
294pub fn spawn_per_thread(
302 addr: &str,
303 router: Router,
304 cfg: PerThreadConfig,
305) -> io::Result<(Vec<std::thread::JoinHandle<()>>, PerThreadShutdown)> {
306 let socket_addr =
307 SocketAddr::from_str(addr).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
308
309 let router: &'static Router = Box::leak(Box::new(router));
312
313 let shutdown = PerThreadShutdown::new();
314 let mut handles = Vec::with_capacity(cfg.workers);
315 for worker_id in 0..cfg.workers {
316 let cfg = cfg.clone();
317 let shutdown = shutdown.clone();
318 let h = std::thread::Builder::new()
319 .name(format!("tako-pt-{worker_id}"))
320 .spawn(move || worker_main(worker_id, socket_addr, router, cfg, shutdown))
321 .expect("spawn tako-pt worker");
322 handles.push(h);
323 }
324 Ok((handles, shutdown))
325}
326
327#[cfg_attr(not(feature = "affinity"), allow(unused_variables))]
333fn worker_main(
334 worker_id: usize,
335 addr: SocketAddr,
336 router: &'static Router,
337 cfg: PerThreadConfig,
338 shutdown: PerThreadShutdown,
339) {
340 #[cfg(feature = "affinity")]
341 if cfg.pin_to_core {
342 if let Some(ids) = core_affinity::get_core_ids() {
343 if let Some(id) = ids.get(worker_id) {
344 if !core_affinity::set_for_current(*id) {
345 tracing::warn!(
346 worker_id,
347 "pin_to_core: core_affinity::set_for_current returned false; running without affinity"
348 );
349 }
350 } else {
351 tracing::warn!(
352 worker_id,
353 available_cores = ids.len(),
354 "pin_to_core: worker_id exceeds available cores; running without affinity"
355 );
356 }
357 } else {
358 tracing::warn!(
359 worker_id,
360 "pin_to_core: core_affinity::get_core_ids() returned None; running without affinity"
361 );
362 }
363 }
364
365 let rt = match Builder::new_current_thread().enable_all().build() {
366 Ok(rt) => rt,
367 Err(e) => {
368 tracing::error!("worker {worker_id}: failed to build runtime: {e}");
369 shutdown.report_bind_failure(io::Error::other(format!(
373 "worker {worker_id}: failed to build runtime: {e}"
374 )));
375 return;
376 }
377 };
378
379 let local = LocalSet::new();
380 local.block_on(&rt, async move {
381 let listener = match bind_reuseport(addr, cfg.backlog) {
382 Ok(l) => {
383 shutdown.report_bind_success();
386 l
387 }
388 Err(e) => {
389 tracing::error!("worker {worker_id}: bind failed: {e}");
390 shutdown.report_bind_failure(e);
391 return;
392 }
393 };
394 tracing::debug!("tako-pt worker {worker_id} listening on {addr}");
395
396 let shutdown_fut = shutdown.notified();
397 tokio::pin!(shutdown_fut);
398
399 let mut connection_handles: tokio::task::JoinSet<()> = tokio::task::JoinSet::new();
406
407 loop {
408 tokio::select! {
409 accept = listener.accept() => {
410 let (stream, peer) = match accept {
411 Ok(v) => v,
412 Err(e) => {
413 tracing::warn!("worker {worker_id}: accept failed: {e}");
414 continue;
415 }
416 };
417 if let Err(e) = stream.set_nodelay(true) {
422 tracing::debug!("worker {worker_id}: set_nodelay failed for {peer}: {e}");
423 }
424 let io = hyper_util::rt::TokioIo::new(stream);
425
426 connection_handles.spawn_local(async move {
427 let svc = service_fn(move |mut req| async move {
428 req.extensions_mut().insert(peer);
429 req.extensions_mut().insert(ConnInfo::tcp(peer));
430 let resp = router.dispatch(req.map(TakoBody::incoming)).await;
431 Ok::<_, Infallible>(resp)
432 });
433
434 let mut http = http1::Builder::new();
435 http.keep_alive(true);
436 http.pipeline_flush(true);
437 if let Err(err) = http.serve_connection(io, svc).with_upgrades().await {
438 if err.is_incomplete_message() {
439 tracing::debug!("worker {worker_id}: client disconnected mid-message: {err}");
440 } else {
441 tracing::error!("worker {worker_id}: connection error: {err}");
442 }
443 }
444 });
445
446 while connection_handles.try_join_next().is_some() {}
451 }
452 () = &mut shutdown_fut => {
453 tracing::info!("worker {worker_id}: shutdown signalled, draining");
454 break;
455 }
456 }
457 }
458 let drain = tokio::time::timeout(cfg.drain_timeout, async {
463 while connection_handles.join_next().await.is_some() {}
464 });
465 let _ = drain.await;
466 });
467}
468
/// Runs a thread-per-core server on `addr` with one compio runtime per
/// worker, blocking the caller until Ctrl-C.
///
/// Spawns `cfg.workers` worker threads and waits until at least one has bound
/// its listener. It then waits for Ctrl-C on a small Tokio runtime on the
/// calling thread. On every return path all workers are signalled and joined.
///
/// # Errors
/// - `InvalidInput` if `addr` is not a valid socket address or `cfg.workers`
///   is 0.
/// - The OS error from spawning a worker thread.
/// - The Ctrl-C helper runtime cannot be built.
/// - Every worker failed to bind (the first bind error is returned).
/// - The Ctrl-C handler could not be installed.
#[cfg(feature = "compio")]
#[cfg_attr(docsrs, doc(cfg(feature = "compio")))]
pub fn serve_per_thread_compio(addr: &str, router: Router, cfg: PerThreadConfig) -> io::Result<()> {
  let socket_addr =
    SocketAddr::from_str(addr).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
  if cfg.workers == 0 {
    return Err(io::Error::new(
      io::ErrorKind::InvalidInput,
      "per-thread server needs at least one worker",
    ));
  }

  // Workers need a `'static` router; it is leaked for the rest of the process.
  let router: &'static Router = Box::leak(Box::new(router));

  let workers = cfg.workers;
  let shutdown = PerThreadShutdown::new();
  let mut handles = Vec::with_capacity(workers);
  let mut spawn_err = None;
  for worker_id in 0..workers {
    let worker_cfg = cfg.clone();
    let worker_shutdown = shutdown.clone();
    let spawned = std::thread::Builder::new()
      .name(format!("tako-pt-compio-{worker_id}"))
      .spawn(move || {
        worker_main_compio(worker_id, socket_addr, router, worker_cfg, worker_shutdown)
      });
    match spawned {
      Ok(handle) => handles.push(handle),
      Err(e) => {
        spawn_err = Some(io::Error::new(
          e.kind(),
          format!("failed to spawn tako-pt-compio worker {worker_id}: {e}"),
        ));
        break;
      }
    }
  }

  let result = match spawn_err {
    Some(err) => Err(err),
    None => tokio::runtime::Builder::new_current_thread()
      .enable_all()
      .build()
      .map_err(|e| io::Error::other(format!("ctrl-c runtime: {e}")))
      .and_then(|rt| {
        rt.block_on(async {
          shutdown.wait_for_bind_outcome(workers).await?;
          tokio::signal::ctrl_c()
            .await
            .map_err(|e| io::Error::new(e.kind(), format!("failed to listen for ctrl-c: {e}")))
        })
      }),
  };

  // Stop and join the workers on every path so none is left serving.
  shutdown.trigger();
  for handle in handles {
    if handle.join().is_err() {
      tracing::error!("tako-pt-compio worker panicked");
    }
  }
  result
}
513
/// Tracks one live compio connection task.
///
/// Constructing it increments the shared in-flight counter. Dropping it
/// decrements the counter and wakes everyone waiting on `drain_notify`.
#[cfg(feature = "compio")]
struct PtConnGuard {
  inflight: std::sync::Arc<std::sync::atomic::AtomicUsize>,
  drain_notify: std::sync::Arc<tokio::sync::Notify>,
}
528
#[cfg(feature = "compio")]
impl PtConnGuard {
  /// Registers one more in-flight connection and returns the guard that
  /// unregisters it when dropped.
  fn new(
    inflight: std::sync::Arc<std::sync::atomic::AtomicUsize>,
    drain_notify: std::sync::Arc<tokio::sync::Notify>,
  ) -> Self {
    let guard = Self {
      inflight,
      drain_notify,
    };
    guard
      .inflight
      .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    guard
  }
}
542
#[cfg(feature = "compio")]
impl Drop for PtConnGuard {
  /// Unregisters the connection, then wakes drain waiters so they re-check
  /// the in-flight count.
  fn drop(&mut self) {
    use std::sync::atomic::Ordering;

    let _previous = self.inflight.fetch_sub(1, Ordering::SeqCst);
    self.drain_notify.notify_waiters();
  }
}
552
/// Body of one compio-based per-thread worker.
///
/// Start-up: optionally pins the thread to core `worker_id`, builds a compio
/// runtime, and binds this worker's own listener. The outcome is reported
/// through `shutdown`.
///
/// Serving: HTTP/1 connections are served as detached compio tasks until
/// shutdown is requested. Failed `accept`s back off exponentially (see
/// `compio_accept_backoff`, capped at 1 s).
///
/// Shutdown: the listener is closed and every connection is told to shut down
/// gracefully (in-flight requests finish, idle keep-alive connections close).
/// The worker then waits up to `cfg.drain_timeout` for the in-flight count,
/// maintained by `PtConnGuard`, to reach zero.
#[cfg(feature = "compio")]
#[cfg_attr(not(feature = "affinity"), allow(unused_variables))]
fn worker_main_compio(
  worker_id: usize,
  addr: SocketAddr,
  router: &'static Router,
  cfg: PerThreadConfig,
  shutdown: PerThreadShutdown,
) {
  use std::sync::Arc;
  use std::sync::atomic::AtomicUsize;
  use std::sync::atomic::Ordering;

  use cyper_core::HyperStream;
  use futures_util::future::Either;
  use futures_util::future::select;
  use tokio::sync::Notify;

  #[cfg(feature = "affinity")]
  if cfg.pin_to_core {
    if let Some(ids) = core_affinity::get_core_ids() {
      if let Some(id) = ids.get(worker_id) {
        if !core_affinity::set_for_current(*id) {
          tracing::warn!(
            worker_id,
            "pin_to_core: core_affinity::set_for_current returned false; running without affinity"
          );
        }
      } else {
        tracing::warn!(
          worker_id,
          available_cores = ids.len(),
          "pin_to_core: worker_id exceeds available cores; running without affinity"
        );
      }
    } else {
      tracing::warn!(
        worker_id,
        "pin_to_core: core_affinity::get_core_ids() returned None; running without affinity"
      );
    }
  }

  let rt = match compio::runtime::RuntimeBuilder::new().build() {
    Ok(rt) => rt,
    Err(e) => {
      tracing::error!("worker {worker_id}: failed to build compio runtime: {e}");
      shutdown.report_bind_failure(io::Error::other(format!(
        "worker {worker_id}: failed to build compio runtime: {e}"
      )));
      return;
    }
  };

  rt.block_on(async move {
    let listener = match bind_reuseport_compio(addr, cfg.backlog) {
      Ok(l) => {
        shutdown.report_bind_success();
        l
      }
      Err(e) => {
        tracing::error!("worker {worker_id}: bind failed: {e}");
        shutdown.report_bind_failure(e);
        return;
      }
    };
    tracing::debug!("tako-pt-compio worker {worker_id} listening on {addr}");

    let cancel = shutdown.inner.clone();
    let mut backoff = compio_accept_backoff();
    // Detached compio tasks cannot be joined. Instead, each connection holds a
    // `PtConnGuard` that maintains this count and signals `drain_notify`.
    let inflight = Arc::new(AtomicUsize::new(0));
    let drain_notify = Arc::new(Notify::new());

    loop {
      let accept_fut = listener.accept();
      let cancel_fut = cancel.cancelled();
      tokio::pin!(accept_fut, cancel_fut);
      let accept = select(accept_fut, cancel_fut).await;
      let (stream, peer) = match accept {
        Either::Left((Ok(v), _)) => {
          backoff = compio_accept_backoff();
          v
        }
        Either::Left((Err(e), _)) => {
          let delay = backoff;
          tracing::warn!("worker {worker_id}: accept failed: {e}; backing off {delay:?}");
          compio::time::sleep(delay).await;
          backoff = std::cmp::min(backoff * 2, Duration::from_secs(1));
          continue;
        }
        Either::Right(_) => {
          tracing::info!("worker {worker_id}: shutdown signalled, draining");
          break;
        }
      };
      if let Err(e) = stream.set_nodelay(true) {
        tracing::debug!("worker {worker_id}: set_nodelay failed for {peer}: {e}");
      }
      let io = HyperStream::new(stream);
      let guard = PtConnGuard::new(inflight.clone(), drain_notify.clone());
      let conn_cancel = cancel.clone();

      compio::runtime::spawn(async move {
        // Hold the guard for the connection's whole lifetime; dropping it
        // decrements `inflight` and wakes the drain loop.
        let _guard = guard;
        let svc = service_fn(move |mut req| async move {
          req.extensions_mut().insert(peer);
          req.extensions_mut().insert(ConnInfo::tcp(peer));
          let resp = router
            .dispatch(req.map(tako_rs_core::body::TakoBody::new))
            .await;
          Ok::<_, Infallible>(resp)
        });

        let mut http = http1::Builder::new();
        http.keep_alive(true);
        let conn = http.serve_connection(io, svc).with_upgrades();
        let mut conn = std::pin::pin!(conn);
        let cancelled = std::pin::pin!(conn_cancel.cancelled());
        let finished_early = match select(conn.as_mut(), cancelled).await {
          Either::Left((res, _)) => Some(res),
          Either::Right(_) => None,
        };
        let result = match finished_early {
          Some(res) => res,
          None => {
            // Shutdown requested: let the current request finish, then close
            // instead of idling on keep-alive until the drain deadline.
            conn.as_mut().graceful_shutdown();
            conn.await
          }
        };
        if let Err(err) = result {
          if err.is_incomplete_message() {
            tracing::debug!("worker {worker_id}: client disconnected mid-message: {err}");
          } else {
            tracing::error!("worker {worker_id}: connection error: {err}");
          }
        }
      })
      .detach();
    }

    // Close the listener before draining so no new connections queue on it.
    drop(listener);

    let drain_deadline = std::time::Instant::now() + cfg.drain_timeout;
    let mut timed_out = false;
    loop {
      // Register for the wake-up before reading the counter (same pattern as
      // `wait_for_bind_outcome`), so a guard dropped in between is not missed.
      let mut wake = std::pin::pin!(drain_notify.notified());
      wake.as_mut().enable();
      if inflight.load(Ordering::SeqCst) == 0 {
        break;
      }
      let now = std::time::Instant::now();
      if now >= drain_deadline {
        timed_out = true;
        break;
      }
      let sleep = std::pin::pin!(compio::time::sleep(drain_deadline - now));
      if let Either::Right(_) = select(wake, sleep).await {
        timed_out = true;
        break;
      }
    }
    if timed_out {
      tracing::warn!(
        worker_id,
        drain_timeout = ?cfg.drain_timeout,
        still_inflight = inflight.load(Ordering::SeqCst),
        "drain timeout exceeded; remaining connections will be aborted"
      );
    }
  });
}