1use std::convert::Infallible;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use hyper::body::Incoming;
11use hyper::service::Service;
12use hyper::{Request, Response};
13use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
14use hyper_util::server::conn::auto;
15use hyper_util::server::graceful::GracefulShutdown;
16use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
17use tokio::net::TcpListener;
18
19use crate::app::AppInner;
20use crate::body::{box_body, RespBody};
21use crate::constants::GRACEFUL_SHUTDOWN_TIMEOUT;
22use crate::extract::RequestPeerAddr;
23
/// Longest a single TLS handshake may run in `accept_tls`. A connection whose handshake
/// has not completed within this window is dropped without ever being served.
#[cfg(feature = "tls")]
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
28
/// HTTP/2 settings applied to every connection the server accepts.
///
/// Every option starts unset. Only options that have been set are forwarded to hyper's
/// HTTP/2 connection builder, so anything left unset keeps hyper's default. Values are
/// set with the chained, consuming setters below.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Http2Config {
    max_concurrent_streams: Option<u32>,
    keep_alive_interval: Option<Duration>,
    keep_alive_timeout: Option<Duration>,
    initial_stream_window_size: Option<u32>,
    initial_connection_window_size: Option<u32>,
    max_frame_size: Option<u32>,
    max_header_list_size: Option<u32>,
}

impl Http2Config {
    /// Creates a configuration with every option unset (equivalent to [`Default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of concurrent streams a client may open on one connection
    /// (`SETTINGS_MAX_CONCURRENT_STREAMS`).
    #[must_use]
    pub fn max_concurrent_streams(mut self, max: u32) -> Self {
        self.max_concurrent_streams = Some(max);
        self
    }

    /// Sets how often the server sends HTTP/2 keep-alive PING frames.
    #[must_use]
    pub fn keep_alive_interval(mut self, interval: Duration) -> Self {
        self.keep_alive_interval = Some(interval);
        self
    }

    /// Sets how long to wait for a keep-alive PING acknowledgement before the connection
    /// is closed.
    #[must_use]
    pub fn keep_alive_timeout(mut self, timeout: Duration) -> Self {
        self.keep_alive_timeout = Some(timeout);
        self
    }

    /// Sets the initial flow-control window of each stream, in bytes.
    #[must_use]
    pub fn initial_stream_window_size(mut self, bytes: u32) -> Self {
        self.initial_stream_window_size = Some(bytes);
        self
    }

    /// Sets the initial connection-level flow-control window, in bytes.
    #[must_use]
    pub fn initial_connection_window_size(mut self, bytes: u32) -> Self {
        self.initial_connection_window_size = Some(bytes);
        self
    }

    /// Sets the largest frame payload the server accepts, in bytes
    /// (`SETTINGS_MAX_FRAME_SIZE`). HTTP/2 only permits values in `16_384..=16_777_215`.
    #[must_use]
    pub fn max_frame_size(mut self, bytes: u32) -> Self {
        self.max_frame_size = Some(bytes);
        self
    }

    /// Sets the largest header list the server accepts, in bytes
    /// (`SETTINGS_MAX_HEADER_LIST_SIZE`).
    #[must_use]
    pub fn max_header_list_size(mut self, bytes: u32) -> Self {
        self.max_header_list_size = Some(bytes);
        self
    }
}
92
/// HTTP/1 settings applied to every connection the server accepts.
///
/// Every option starts unset. Only options that have been set are forwarded to hyper's
/// HTTP/1 connection builder, so anything left unset keeps hyper's default.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Http1Config {
    keep_alive: Option<bool>,
    max_headers: Option<usize>,
}

impl Http1Config {
    /// Creates a configuration with every option unset (equivalent to [`Default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables HTTP/1 keep-alive (reuse of a connection for several requests).
    #[must_use]
    pub fn keep_alive(mut self, enabled: bool) -> Self {
        self.keep_alive = Some(enabled);
        self
    }

    /// Sets the maximum number of headers parsed per request (hyper's `max_headers`).
    #[must_use]
    pub fn max_headers(mut self, max: usize) -> Self {
        self.max_headers = Some(max);
        self
    }
}
121
122fn configure_builder(builder: &mut auto::Builder<TokioExecutor>, app: &AppInner) {
124 {
125 let mut h1 = builder.http1();
126 h1.timer(TokioTimer::new());
127 if let Some(timeout) = app.header_read_timeout() {
128 h1.header_read_timeout(timeout);
129 }
130 if let Some(config) = app.http1_config() {
131 if let Some(enabled) = config.keep_alive {
132 h1.keep_alive(enabled);
133 }
134 if let Some(max) = config.max_headers {
135 h1.max_headers(max);
136 }
137 }
138 }
139 {
140 let mut h2 = builder.http2();
141 h2.timer(TokioTimer::new());
142 if let Some(config) = app.http2_config() {
143 if let Some(max) = config.max_concurrent_streams {
144 h2.max_concurrent_streams(max);
145 }
146 if let Some(interval) = config.keep_alive_interval {
147 h2.keep_alive_interval(interval);
148 }
149 if let Some(timeout) = config.keep_alive_timeout {
150 h2.keep_alive_timeout(timeout);
151 }
152 if let Some(bytes) = config.initial_stream_window_size {
153 h2.initial_stream_window_size(bytes);
154 }
155 if let Some(bytes) = config.initial_connection_window_size {
156 h2.initial_connection_window_size(bytes);
157 }
158 if let Some(bytes) = config.max_frame_size {
159 h2.max_frame_size(bytes);
160 }
161 if let Some(bytes) = config.max_header_list_size {
162 h2.max_header_list_size(bytes);
163 }
164 }
165 }
166}
167
168#[derive(Clone)]
174pub struct TorkService {
175 app: Arc<AppInner>,
176 peer_addr: Option<std::net::SocketAddr>,
177}
178
179impl TorkService {
180 pub fn new(app: Arc<AppInner>) -> Self {
182 Self {
183 app,
184 peer_addr: None,
185 }
186 }
187
188 pub(crate) fn with_peer_addr(app: Arc<AppInner>, peer_addr: std::net::SocketAddr) -> Self {
189 Self {
190 app,
191 peer_addr: Some(peer_addr),
192 }
193 }
194}
195
196impl Service<Request<Incoming>> for TorkService {
197 type Response = Response<RespBody>;
198 type Error = Infallible;
199 type Future =
200 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
201
202 fn call(&self, request: Request<Incoming>) -> Self::Future {
203 let app = self.app.clone();
204 let peer_addr = self.peer_addr;
205 Box::pin(async move {
206 let (mut parts, incoming) = request.into_parts();
208 if let Some(peer_addr) = peer_addr {
209 parts.extensions.insert(RequestPeerAddr(peer_addr));
210 }
211 let request = Request::from_parts(parts, box_body(incoming));
212 Ok(app.handle(request).await)
213 })
214 }
215}
216
217pub(crate) trait IncomingListener {
222 type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
223
224 fn accept_io(
225 &self,
226 ) -> impl Future<Output = std::io::Result<(Self::Io, Option<std::net::SocketAddr>)>> + Send;
227}
228
229impl IncomingListener for TcpListener {
230 type Io = tokio::net::TcpStream;
231
232 async fn accept_io(&self) -> std::io::Result<(Self::Io, Option<std::net::SocketAddr>)> {
233 let (stream, peer) = self.accept().await?;
234 Ok((stream, Some(peer)))
235 }
236}
237
238#[cfg(unix)]
239impl IncomingListener for tokio::net::UnixListener {
240 type Io = tokio::net::UnixStream;
241
242 async fn accept_io(&self) -> std::io::Result<(Self::Io, Option<std::net::SocketAddr>)> {
243 let (stream, _addr) = self.accept().await?;
244 Ok((stream, None))
245 }
246}
247
248pub(crate) async fn bind_tcp_listener(addr: &str, reuse_port: bool) -> std::io::Result<TcpListener> {
254 if !reuse_port {
255 return TcpListener::bind(addr).await;
256 }
257
258 let resolved = tokio::net::lookup_host(addr).await?.next().ok_or_else(|| {
259 std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, "no address resolved")
260 })?;
261 let domain = if resolved.is_ipv6() {
262 socket2::Domain::IPV6
263 } else {
264 socket2::Domain::IPV4
265 };
266 let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
267 socket.set_reuse_address(true)?;
268 #[cfg(unix)]
269 socket.set_reuse_port(true)?;
270 socket.set_nonblocking(true)?;
271 socket.bind(&resolved.into())?;
272 socket.listen(1024)?;
273 TcpListener::from_std(socket.into())
274}
275
276pub(crate) async fn run_with_shutdown<S, L>(app: Arc<AppInner>, listener: L, shutdown: S)
281where
282 S: Future<Output = ()>,
283 L: IncomingListener,
284{
285 let mut builder = auto::Builder::new(TokioExecutor::new());
286 configure_builder(&mut builder, &app);
289 let graceful = GracefulShutdown::new();
290 let mut shutdown = std::pin::pin!(shutdown);
291
292 #[cfg(feature = "tls")]
295 if let Some(acceptor) = app.tls_acceptor().cloned() {
296 accept_tls(&app, &listener, &builder, &graceful, &mut shutdown, acceptor).await;
297 } else {
298 accept_plain(&app, &listener, &builder, &graceful, &mut shutdown).await;
299 }
300 #[cfg(not(feature = "tls"))]
301 accept_plain(&app, &listener, &builder, &graceful, &mut shutdown).await;
302
303 app.begin_ws_shutdown();
307
308 drain_with_timeout(
310 graceful.shutdown(),
311 tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT),
312 )
313 .await;
314}
315
316async fn accept_plain<S, L>(
318 app: &Arc<AppInner>,
319 listener: &L,
320 builder: &auto::Builder<TokioExecutor>,
321 graceful: &GracefulShutdown,
322 shutdown: &mut Pin<&mut S>,
323) where
324 S: Future<Output = ()>,
325 L: IncomingListener,
326{
327 loop {
328 tokio::select! {
329 accepted = listener.accept_io() => {
330 let _ = handle_accepted_connection(app.clone(), builder, graceful, accepted).await;
331 }
332 _ = shutdown.as_mut() => break,
333 }
334 }
335}
336
/// TLS accept loop: serves every connection whose TLS handshake succeeds, until `shutdown`
/// resolves.
///
/// Handshakes run in their own tasks, bounded by `TLS_HANDSHAKE_TIMEOUT`, so a slow client
/// cannot stall the loop. Completed handshakes are sent back over a bounded channel and
/// served from here. Failed or timed-out handshakes are dropped.
///
/// Accept errors that only concern the pending connection (refused, aborted or reset) are
/// skipped immediately. Any other accept error makes the loop pause for
/// `ACCEPT_ERROR_BACKOFF` instead of spinning the CPU; `shutdown` is still honoured during
/// the pause. Finished handshakes wait in the channel until the pause ends.
#[cfg(feature = "tls")]
async fn accept_tls<S, L>(
    app: &Arc<AppInner>,
    listener: &L,
    builder: &auto::Builder<TokioExecutor>,
    graceful: &GracefulShutdown,
    shutdown: &mut Pin<&mut S>,
    acceptor: tokio_rustls::TlsAcceptor,
) where
    S: Future<Output = ()>,
    L: IncomingListener,
{
    // Pause after a listener-level accept error before retrying.
    const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);

    // Errors that affect only the one pending connection; the listener itself is healthy.
    fn is_connection_error(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
        )
    }

    type Handshaked<Io> = (tokio_rustls::server::TlsStream<Io>, Option<std::net::SocketAddr>);
    let (handshake_tx, mut handshake_rx) =
        tokio::sync::mpsc::channel::<Handshaked<L::Io>>(256);

    loop {
        tokio::select! {
            accepted = listener.accept_io() => {
                match accepted {
                    Ok((stream, peer)) => {
                        let acceptor = acceptor.clone();
                        let handshake_tx = handshake_tx.clone();
                        tokio::spawn(async move {
                            if let Ok(Ok(tls)) =
                                tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, acceptor.accept(stream)).await
                            {
                                let _ = handshake_tx.send((tls, peer)).await;
                            }
                        });
                    }
                    Err(err) if is_connection_error(&err) => {}
                    Err(_) => {
                        tokio::select! {
                            _ = tokio::time::sleep(ACCEPT_ERROR_BACKOFF) => {}
                            _ = shutdown.as_mut() => break,
                        }
                    }
                }
            }
            Some((tls, peer)) = handshake_rx.recv() => {
                let _ = handle_accepted_connection(app.clone(), builder, graceful, Ok((tls, peer))).await;
            }
            _ = shutdown.as_mut() => break,
        }
    }
}
381
382pub(crate) async fn shutdown_signal() {
384 let interrupt = async {
385 let _ = tokio::signal::ctrl_c().await;
386 };
387
388 #[cfg(unix)]
389 let terminate = async {
390 use tokio::signal::unix::{signal, SignalKind};
391 if let Ok(mut stream) = signal(SignalKind::terminate()) {
392 stream.recv().await;
393 }
394 };
395
396 #[cfg(not(unix))]
397 let terminate = std::future::pending::<()>();
398
399 shutdown_signal_with(interrupt, terminate).await;
400}
401
402async fn handle_accepted_connection<S>(
403 app: Arc<AppInner>,
404 builder: &auto::Builder<TokioExecutor>,
405 graceful: &GracefulShutdown,
406 accepted: std::io::Result<(S, Option<std::net::SocketAddr>)>,
407) -> bool
408where
409 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
410{
411 let (stream, peer) = match accepted {
412 Ok(pair) => pair,
413 Err(_) => return false,
416 };
417
418 match app.idle_timeout() {
421 Some(idle) => serve_io(app, builder, graceful, IdleTimeoutStream::new(stream, idle), peer),
422 None => serve_io(app, builder, graceful, stream, peer),
423 }
424
425 true
426}
427
428fn serve_io<IO>(
431 app: Arc<AppInner>,
432 builder: &auto::Builder<TokioExecutor>,
433 graceful: &GracefulShutdown,
434 stream: IO,
435 peer: Option<std::net::SocketAddr>,
436) where
437 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
438{
439 let io = TokioIo::new(stream);
440 let service = match peer {
441 Some(peer) => TorkService::with_peer_addr(app, peer),
442 None => TorkService::new(app),
443 };
444 let connection = builder.serve_connection_with_upgrades(io, service);
445 let watched = graceful.watch(connection.into_owned());
446
447 tokio::spawn(async move {
448 let _ = watched.await;
451 });
452}
453
454struct IdleTimeoutStream<S> {
458 inner: S,
459 timer: Pin<Box<tokio::time::Sleep>>,
460 idle: Duration,
461}
462
463impl<S> IdleTimeoutStream<S> {
464 fn new(inner: S, idle: Duration) -> Self {
465 Self {
466 inner,
467 timer: Box::pin(tokio::time::sleep(idle)),
468 idle,
469 }
470 }
471
472 fn touch(&mut self) {
474 self.timer
475 .as_mut()
476 .reset(tokio::time::Instant::now() + self.idle);
477 }
478
479 fn idle_expired(&mut self, cx: &mut Context<'_>) -> bool {
481 self.timer.as_mut().poll(cx).is_ready()
482 }
483}
484
485impl<S: AsyncRead + Unpin> AsyncRead for IdleTimeoutStream<S> {
486 fn poll_read(
487 self: Pin<&mut Self>,
488 cx: &mut Context<'_>,
489 buf: &mut ReadBuf<'_>,
490 ) -> Poll<std::io::Result<()>> {
491 let this = self.get_mut();
492 if this.idle_expired(cx) {
493 return Poll::Ready(Err(std::io::Error::new(
494 std::io::ErrorKind::TimedOut,
495 "connection idle timeout",
496 )));
497 }
498 let before = buf.filled().len();
499 match Pin::new(&mut this.inner).poll_read(cx, buf) {
500 Poll::Ready(Ok(())) => {
501 if buf.filled().len() != before {
502 this.touch();
503 }
504 Poll::Ready(Ok(()))
505 }
506 other => other,
507 }
508 }
509}
510
511impl<S: AsyncWrite + Unpin> AsyncWrite for IdleTimeoutStream<S> {
512 fn poll_write(
513 self: Pin<&mut Self>,
514 cx: &mut Context<'_>,
515 buf: &[u8],
516 ) -> Poll<std::io::Result<usize>> {
517 let this = self.get_mut();
518 match Pin::new(&mut this.inner).poll_write(cx, buf) {
519 Poll::Ready(Ok(written)) => {
520 this.touch();
521 Poll::Ready(Ok(written))
522 }
523 other => other,
524 }
525 }
526
527 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
528 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
529 }
530
531 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
532 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
533 }
534}
535
536async fn drain_with_timeout<F, T>(shutdown: F, timeout: T)
537where
538 F: Future<Output = ()>,
539 T: Future<Output = ()>,
540{
541 tokio::select! {
542 _ = shutdown => {}
543 _ = timeout => {}
544 }
545}
546
547async fn shutdown_signal_with<I, T>(interrupt: I, terminate: T)
548where
549 I: Future<Output = ()>,
550 T: Future<Output = ()>,
551{
552 tokio::select! {
553 _ = interrupt => {}
554 _ = terminate => {}
555 }
556}
557
// Integration-style tests for the server module: they bind real loopback sockets and run
// the accept loop on the tokio test runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::App;
    use crate::extract::RequestContext;
    use crate::response::Response as TorkResponse;
    use crate::router::{BoxFuture, HandlerFn, Route, Router};
    use crate::{json_response, Method, StatusCode};

    use std::future;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::oneshot;

    // End-to-end: a raw HTTP/1.1 request over TCP reaches a routed handler and its JSON
    // response comes back.
    #[tokio::test]
    async fn serves_a_request_over_tcp() {
        let handler: HandlerFn = Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
                Box::pin(async {
                    Ok(json_response(
                        StatusCode::OK,
                        &serde_json::json!({ "pong": true }),
                    ))
                })
            },
        );
        let router = Router::new().route(Route::new(Method::GET, "/ping", handler));
        let app = Arc::new(App::new().include_router(router).build().unwrap());

        // Port 0 lets the OS pick a free port; the actual address is read back below.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The server runs until the oneshot fires (or its sender is dropped).
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let server = tokio::spawn(run_with_shutdown(app, listener, async move {
            let _ = shutdown_rx.await;
        }));

        // `Connection: close` makes the server close the socket after responding, so
        // `read_to_string` below terminates.
        let mut stream = TcpStream::connect(addr).await.unwrap();
        stream
            .write_all(b"GET /ping HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .unwrap();

        let mut response = String::new();
        stream.read_to_string(&mut response).await.unwrap();

        assert!(
            response.contains("200 OK"),
            "unexpected response: {response}"
        );
        assert!(
            response.contains("\"pong\":true"),
            "unexpected body: {response}"
        );

        let _ = shutdown_tx.send(());
        let _ = server.await;
    }

    // Exercises the small helpers directly: the accept-error path, the success path, and
    // both completion orders of the two race helpers.
    #[tokio::test]
    async fn helper_paths_cover_accept_errors_shutdown_and_signals() {
        let builder = auto::Builder::new(TokioExecutor::new());
        let app = Arc::new(App::new().build().unwrap());
        let graceful = GracefulShutdown::new();

        // A failed accept is reported as `false` and nothing is served.
        assert!(
            !handle_accepted_connection::<tokio::io::DuplexStream>(
                app.clone(),
                &builder,
                &graceful,
                Err(std::io::Error::other("accept failed"))
            )
            .await
        );

        // An in-memory duplex pipe stands in for a socket; the other end is unused.
        let (stream, _peer) = tokio::io::duplex(16);
        assert!(
            handle_accepted_connection(
                app,
                &builder,
                &graceful,
                Ok((stream, Some("127.0.0.1:0".parse().unwrap())))
            )
            .await
        );

        // Each helper must return when either of its two inputs completes.
        drain_with_timeout(future::ready(()), future::pending::<()>()).await;
        drain_with_timeout(future::pending::<()>(), future::ready(())).await;

        shutdown_signal_with(future::ready(()), future::pending::<()>()).await;
        shutdown_signal_with(future::pending::<()>(), future::ready(())).await;
    }

    // Compile-level check that the public service can be constructed and cloned.
    #[tokio::test]
    async fn tork_service_new_returns_cloneable_service() {
        let app = Arc::new(App::new().build().unwrap());
        let service = TorkService::new(app);
        let _cloned = service.clone();
    }

    // With an already-resolved shutdown future and no pending connections, the server
    // returns instead of blocking on `accept`.
    #[tokio::test]
    async fn run_with_shutdown_breaks_when_shutdown_resolves_first() {
        let app = Arc::new(App::new().build().unwrap());
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        run_with_shutdown(app, listener, future::ready(())).await;
    }

    // With reuse-port enabled, a second listener can bind the exact port the first one
    // received.
    #[cfg(unix)]
    #[tokio::test]
    async fn reuse_port_allows_two_listeners_on_the_same_port() {
        let first = bind_tcp_listener("127.0.0.1:0", true).await.unwrap();
        let addr = first.local_addr().unwrap();
        let second = bind_tcp_listener(&addr.to_string(), true).await.unwrap();
        assert_eq!(second.local_addr().unwrap().port(), addr.port());
    }
}