Skip to main content

tork_core/
server.rs

1//! The HTTP server: a Hyper accept loop with graceful shutdown.
2
3use std::convert::Infallible;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use hyper::body::Incoming;
11use hyper::service::Service;
12use hyper::{Request, Response};
13use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
14use hyper_util::server::conn::auto;
15use hyper_util::server::graceful::GracefulShutdown;
16use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
17use tokio::net::TcpListener;
18
19use crate::app::AppInner;
20use crate::body::{box_body, RespBody};
21use crate::constants::GRACEFUL_SHUTDOWN_TIMEOUT;
22use crate::extract::RequestPeerAddr;
23
/// Maximum time allowed for a TLS handshake to complete before the pending
/// connection is dropped, so a stalled client cannot hold a slot.
///
/// Only compiled with the `tls` feature; the TLS accept loop wraps every
/// handshake in a timeout of this length.
#[cfg(feature = "tls")]
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
28
/// HTTP/2 connection tuning applied to every served connection.
///
/// Each unset field keeps hyper's default: a value is only forwarded to the
/// connection builder when it has been set explicitly. Configure with
/// [`App::http2`](crate::App::http2).
///
/// The type is `Debug` so a configuration can be logged, and `PartialEq`/`Eq`
/// so two configurations can be compared (for example in tests).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Http2Config {
    max_concurrent_streams: Option<u32>,
    keep_alive_interval: Option<Duration>,
    keep_alive_timeout: Option<Duration>,
    initial_stream_window_size: Option<u32>,
    initial_connection_window_size: Option<u32>,
    max_frame_size: Option<u32>,
    max_header_list_size: Option<u32>,
}

impl Http2Config {
    /// Creates an empty config (every limit at hyper's default).
    pub fn new() -> Self {
        Self::default()
    }

    /// Caps the number of concurrent streams a peer may open on one connection.
    pub fn max_concurrent_streams(mut self, max: u32) -> Self {
        self.max_concurrent_streams = Some(max);
        self
    }

    /// Sends an HTTP/2 PING on an idle connection at this interval (keep-alive).
    pub fn keep_alive_interval(mut self, interval: Duration) -> Self {
        self.keep_alive_interval = Some(interval);
        self
    }

    /// Closes the connection if a keep-alive PING is not answered within this time.
    pub fn keep_alive_timeout(mut self, timeout: Duration) -> Self {
        self.keep_alive_timeout = Some(timeout);
        self
    }

    /// Sets the initial per-stream flow-control window, in bytes.
    pub fn initial_stream_window_size(mut self, bytes: u32) -> Self {
        self.initial_stream_window_size = Some(bytes);
        self
    }

    /// Sets the initial connection-level flow-control window, in bytes.
    pub fn initial_connection_window_size(mut self, bytes: u32) -> Self {
        self.initial_connection_window_size = Some(bytes);
        self
    }

    /// Sets the largest frame payload the server will accept, in bytes.
    ///
    /// The value is stored as given and not range-checked here.
    // NOTE(review): HTTP/2 restricts SETTINGS_MAX_FRAME_SIZE to 2^14..=2^24-1;
    // confirm how the underlying HTTP/2 implementation reacts to values outside it.
    pub fn max_frame_size(mut self, bytes: u32) -> Self {
        self.max_frame_size = Some(bytes);
        self
    }

    /// Sets the maximum size of the decoded request header block, in bytes.
    pub fn max_header_list_size(mut self, bytes: u32) -> Self {
        self.max_header_list_size = Some(bytes);
        self
    }
}
92
/// HTTP/1 connection tuning applied to every served connection.
///
/// Each unset field keeps hyper's default: a value is only forwarded to the
/// connection builder when it has been set explicitly. Configure with
/// [`App::http1`](crate::App::http1).
///
/// The type is `Debug` so a configuration can be logged, and `PartialEq`/`Eq`
/// so two configurations can be compared (for example in tests).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Http1Config {
    keep_alive: Option<bool>,
    max_headers: Option<usize>,
}

impl Http1Config {
    /// Creates an empty config (every setting at hyper's default).
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables HTTP/1 keep-alive (persistent connections).
    pub fn keep_alive(mut self, enabled: bool) -> Self {
        self.keep_alive = Some(enabled);
        self
    }

    /// Sets the maximum number of request headers accepted.
    pub fn max_headers(mut self, max: usize) -> Self {
        self.max_headers = Some(max);
        self
    }
}
121
122/// Applies the app's HTTP/1 + HTTP/2 tuning onto the connection builder.
123fn configure_builder(builder: &mut auto::Builder<TokioExecutor>, app: &AppInner) {
124    {
125        let mut h1 = builder.http1();
126        h1.timer(TokioTimer::new());
127        if let Some(timeout) = app.header_read_timeout() {
128            h1.header_read_timeout(timeout);
129        }
130        if let Some(config) = app.http1_config() {
131            if let Some(enabled) = config.keep_alive {
132                h1.keep_alive(enabled);
133            }
134            if let Some(max) = config.max_headers {
135                h1.max_headers(max);
136            }
137        }
138    }
139    {
140        let mut h2 = builder.http2();
141        h2.timer(TokioTimer::new());
142        if let Some(config) = app.http2_config() {
143            if let Some(max) = config.max_concurrent_streams {
144                h2.max_concurrent_streams(max);
145            }
146            if let Some(interval) = config.keep_alive_interval {
147                h2.keep_alive_interval(interval);
148            }
149            if let Some(timeout) = config.keep_alive_timeout {
150                h2.keep_alive_timeout(timeout);
151            }
152            if let Some(bytes) = config.initial_stream_window_size {
153                h2.initial_stream_window_size(bytes);
154            }
155            if let Some(bytes) = config.initial_connection_window_size {
156                h2.initial_connection_window_size(bytes);
157            }
158            if let Some(bytes) = config.max_frame_size {
159                h2.max_frame_size(bytes);
160            }
161            if let Some(bytes) = config.max_header_list_size {
162                h2.max_header_list_size(bytes);
163            }
164        }
165    }
166}
167
168/// A Hyper [`Service`] that hands each request to the application.
169///
170/// The service error type is [`Infallible`]: application errors are rendered
171/// into responses by [`AppInner::dispatch`], so a failing request never tears
172/// down the underlying connection.
173#[derive(Clone)]
174pub struct TorkService {
175    app: Arc<AppInner>,
176    peer_addr: Option<std::net::SocketAddr>,
177}
178
179impl TorkService {
180    /// Creates a service backed by the given application core.
181    pub fn new(app: Arc<AppInner>) -> Self {
182        Self {
183            app,
184            peer_addr: None,
185        }
186    }
187
188    pub(crate) fn with_peer_addr(app: Arc<AppInner>, peer_addr: std::net::SocketAddr) -> Self {
189        Self {
190            app,
191            peer_addr: Some(peer_addr),
192        }
193    }
194}
195
196impl Service<Request<Incoming>> for TorkService {
197    type Response = Response<RespBody>;
198    type Error = Infallible;
199    type Future =
200        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
201
202    fn call(&self, request: Request<Incoming>) -> Self::Future {
203        let app = self.app.clone();
204        let peer_addr = self.peer_addr;
205        Box::pin(async move {
206            // Erase the connection body into the runtime's request body type.
207            let (mut parts, incoming) = request.into_parts();
208            if let Some(peer_addr) = peer_addr {
209                parts.extensions.insert(RequestPeerAddr(peer_addr));
210            }
211            let request = Request::from_parts(parts, box_body(incoming));
212            Ok(app.handle(request).await)
213        })
214    }
215}
216
/// A listening socket the accept loop can drive, abstracting over TCP and Unix.
///
/// `accept_io` yields a connection stream and an optional peer address (Unix-domain
/// connections have no `SocketAddr`).
pub(crate) trait IncomingListener {
    /// The connection stream produced by a successful accept. It must be
    /// `Send + 'static` because each connection is driven on a spawned task.
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Waits for the next incoming connection.
    ///
    /// Resolves to the stream plus the peer's socket address when there is
    /// one, or to the I/O error reported by the underlying accept call.
    fn accept_io(
        &self,
    ) -> impl Future<Output = std::io::Result<(Self::Io, Option<std::net::SocketAddr>)>> + Send;
}
228
229impl IncomingListener for TcpListener {
230    type Io = tokio::net::TcpStream;
231
232    async fn accept_io(&self) -> std::io::Result<(Self::Io, Option<std::net::SocketAddr>)> {
233        let (stream, peer) = self.accept().await?;
234        Ok((stream, Some(peer)))
235    }
236}
237
238#[cfg(unix)]
239impl IncomingListener for tokio::net::UnixListener {
240    type Io = tokio::net::UnixStream;
241
242    async fn accept_io(&self) -> std::io::Result<(Self::Io, Option<std::net::SocketAddr>)> {
243        let (stream, _addr) = self.accept().await?;
244        Ok((stream, None))
245    }
246}
247
248/// Binds a TCP listener for `addr`, optionally setting `SO_REUSEPORT`.
249///
250/// Without `reuse_port` this is a plain `TcpListener::bind`. With it, the socket is
251/// built by hand so `SO_REUSEPORT` (Unix) can be set before binding, letting several
252/// processes share the address.
253pub(crate) async fn bind_tcp_listener(addr: &str, reuse_port: bool) -> std::io::Result<TcpListener> {
254    if !reuse_port {
255        return TcpListener::bind(addr).await;
256    }
257
258    let resolved = tokio::net::lookup_host(addr).await?.next().ok_or_else(|| {
259        std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, "no address resolved")
260    })?;
261    let domain = if resolved.is_ipv6() {
262        socket2::Domain::IPV6
263    } else {
264        socket2::Domain::IPV4
265    };
266    let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
267    socket.set_reuse_address(true)?;
268    #[cfg(unix)]
269    socket.set_reuse_port(true)?;
270    socket.set_nonblocking(true)?;
271    socket.bind(&resolved.into())?;
272    socket.listen(1024)?;
273    TcpListener::from_std(socket.into())
274}
275
276/// Runs the accept loop on `listener`, stopping when `shutdown` resolves.
277///
278/// The application lifecycle around this loop (startup, bind, readiness, drain,
279/// shutdown) lives in [`App::serve`](crate::App::serve).
280pub(crate) async fn run_with_shutdown<S, L>(app: Arc<AppInner>, listener: L, shutdown: S)
281where
282    S: Future<Output = ()>,
283    L: IncomingListener,
284{
285    let mut builder = auto::Builder::new(TokioExecutor::new());
286    // Wire the per-connection timers and the configured HTTP/1 + HTTP/2 tuning
287    // (including the slowloris-bounding header-read timeout) onto the builder.
288    configure_builder(&mut builder, &app);
289    let graceful = GracefulShutdown::new();
290    let mut shutdown = std::pin::pin!(shutdown);
291
292    // With TLS, terminate each connection through the rustls acceptor before
293    // handing it to hyper; otherwise serve the plain TCP stream directly.
294    #[cfg(feature = "tls")]
295    if let Some(acceptor) = app.tls_acceptor().cloned() {
296        accept_tls(&app, &listener, &builder, &graceful, &mut shutdown, acceptor).await;
297    } else {
298        accept_plain(&app, &listener, &builder, &graceful, &mut shutdown).await;
299    }
300    #[cfg(not(feature = "tls"))]
301    accept_plain(&app, &listener, &builder, &graceful, &mut shutdown).await;
302
303    // Tell in-flight WebSocket connections to close cleanly. They run in spawned
304    // upgrade tasks that `GracefulShutdown` does not track, so without this they
305    // would simply be dropped when the runtime stops.
306    app.begin_ws_shutdown();
307
308    // Stop accepting, then drain in-flight connections within the timeout.
309    drain_with_timeout(
310        graceful.shutdown(),
311        tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT),
312    )
313    .await;
314}
315
316/// Accept loop for plain (non-TLS) connections.
317async fn accept_plain<S, L>(
318    app: &Arc<AppInner>,
319    listener: &L,
320    builder: &auto::Builder<TokioExecutor>,
321    graceful: &GracefulShutdown,
322    shutdown: &mut Pin<&mut S>,
323) where
324    S: Future<Output = ()>,
325    L: IncomingListener,
326{
327    loop {
328        tokio::select! {
329            accepted = listener.accept_io() => {
330                let _ = handle_accepted_connection(app.clone(), builder, graceful, accepted).await;
331            }
332            _ = shutdown.as_mut() => break,
333        }
334    }
335}
336
/// Accept loop for TLS connections.
///
/// Each accepted socket is handed to a spawned task that performs the rustls
/// handshake under [`TLS_HANDSHAKE_TIMEOUT`], so a slow handshake never blocks the
/// accept loop. Completed TLS streams come back over a channel and are served (and
/// tracked by `GracefulShutdown`) exactly like a plain connection. Handshakes
/// that fail or time out are dropped silently.
///
/// Accept errors never end the loop, but an error that is not specific to one
/// connection (for example, running out of file descriptors) is followed by a
/// short pause: such errors are returned immediately on every retry, and
/// without the pause the loop would spin at full CPU until the condition clears.
#[cfg(feature = "tls")]
async fn accept_tls<S, L>(
    app: &Arc<AppInner>,
    listener: &L,
    builder: &auto::Builder<TokioExecutor>,
    graceful: &GracefulShutdown,
    shutdown: &mut Pin<&mut S>,
    acceptor: tokio_rustls::TlsAcceptor,
) where
    S: Future<Output = ()>,
    L: IncomingListener,
{
    // Pause between accept attempts after a non-connection accept error.
    const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(50);

    type Handshaked<Io> = (tokio_rustls::server::TlsStream<Io>, Option<std::net::SocketAddr>);
    // Bounded so finished handshakes apply back-pressure instead of piling up.
    let (handshake_tx, mut handshake_rx) =
        tokio::sync::mpsc::channel::<Handshaked<L::Io>>(256);

    loop {
        tokio::select! {
            accepted = listener.accept_io() => {
                match accepted {
                    Ok((stream, peer)) => {
                        let acceptor = acceptor.clone();
                        let handshake_tx = handshake_tx.clone();
                        tokio::spawn(async move {
                            if let Ok(Ok(tls)) =
                                tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, acceptor.accept(stream)).await
                            {
                                // A send error only means the accept loop has
                                // already stopped; the stream is dropped then.
                                let _ = handshake_tx.send((tls, peer)).await;
                            }
                        });
                    }
                    Err(error) => {
                        // Errors tied to one peer (it reset or aborted the
                        // connection) say nothing about the listener, so
                        // retry at once for those.
                        let per_connection = matches!(
                            error.kind(),
                            std::io::ErrorKind::ConnectionRefused
                                | std::io::ErrorKind::ConnectionAborted
                                | std::io::ErrorKind::ConnectionReset
                        );
                        if !per_connection {
                            // Stay responsive to shutdown while backing off.
                            tokio::select! {
                                _ = tokio::time::sleep(ACCEPT_ERROR_BACKOFF) => {}
                                _ = shutdown.as_mut() => break,
                            }
                        }
                    }
                }
            }
            Some((tls, peer)) = handshake_rx.recv() => {
                let _ = handle_accepted_connection(app.clone(), builder, graceful, Ok((tls, peer))).await;
            }
            _ = shutdown.as_mut() => break,
        }
    }
}
381
382/// Resolves when the process receives an interrupt or termination signal.
383pub(crate) async fn shutdown_signal() {
384    let interrupt = async {
385        let _ = tokio::signal::ctrl_c().await;
386    };
387
388    #[cfg(unix)]
389    let terminate = async {
390        use tokio::signal::unix::{signal, SignalKind};
391        if let Ok(mut stream) = signal(SignalKind::terminate()) {
392            stream.recv().await;
393        }
394    };
395
396    #[cfg(not(unix))]
397    let terminate = std::future::pending::<()>();
398
399    shutdown_signal_with(interrupt, terminate).await;
400}
401
402async fn handle_accepted_connection<S>(
403    app: Arc<AppInner>,
404    builder: &auto::Builder<TokioExecutor>,
405    graceful: &GracefulShutdown,
406    accepted: std::io::Result<(S, Option<std::net::SocketAddr>)>,
407) -> bool
408where
409    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
410{
411    let (stream, peer) = match accepted {
412        Ok(pair) => pair,
413        // Transient accept errors (for example, fd exhaustion) should
414        // not bring the server down; skip and keep accepting.
415        Err(_) => return false,
416    };
417
418    // Wrap the stream in an idle-timeout guard when one is configured, so a
419    // connection with no read/write activity is dropped instead of held open.
420    match app.idle_timeout() {
421        Some(idle) => serve_io(app, builder, graceful, IdleTimeoutStream::new(stream, idle), peer),
422        None => serve_io(app, builder, graceful, stream, peer),
423    }
424
425    true
426}
427
428/// Serves one connection: hands the (possibly wrapped) stream to hyper, tracks it
429/// for graceful drain, and drives it on a spawned task.
430fn serve_io<IO>(
431    app: Arc<AppInner>,
432    builder: &auto::Builder<TokioExecutor>,
433    graceful: &GracefulShutdown,
434    stream: IO,
435    peer: Option<std::net::SocketAddr>,
436) where
437    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
438{
439    let io = TokioIo::new(stream);
440    let service = match peer {
441        Some(peer) => TorkService::with_peer_addr(app, peer),
442        None => TorkService::new(app),
443    };
444    let connection = builder.serve_connection_with_upgrades(io, service);
445    let watched = graceful.watch(connection.into_owned());
446
447    tokio::spawn(async move {
448        // Connection-level errors are already terminal for that
449        // connection; nothing actionable remains here.
450        let _ = watched.await;
451    });
452}
453
454/// A stream wrapper that ends the connection after `idle` with no read or write
455/// activity. The struct is `Unpin` (the timer is boxed, the inner stream is
456/// `Unpin`), so the poll methods project without any `unsafe`.
457struct IdleTimeoutStream<S> {
458    inner: S,
459    timer: Pin<Box<tokio::time::Sleep>>,
460    idle: Duration,
461}
462
463impl<S> IdleTimeoutStream<S> {
464    fn new(inner: S, idle: Duration) -> Self {
465        Self {
466            inner,
467            timer: Box::pin(tokio::time::sleep(idle)),
468            idle,
469        }
470    }
471
472    /// Pushes the idle deadline forward after activity.
473    fn touch(&mut self) {
474        self.timer
475            .as_mut()
476            .reset(tokio::time::Instant::now() + self.idle);
477    }
478
479    /// Returns `true` once the idle deadline has passed (and registers a wake-up).
480    fn idle_expired(&mut self, cx: &mut Context<'_>) -> bool {
481        self.timer.as_mut().poll(cx).is_ready()
482    }
483}
484
485impl<S: AsyncRead + Unpin> AsyncRead for IdleTimeoutStream<S> {
486    fn poll_read(
487        self: Pin<&mut Self>,
488        cx: &mut Context<'_>,
489        buf: &mut ReadBuf<'_>,
490    ) -> Poll<std::io::Result<()>> {
491        let this = self.get_mut();
492        if this.idle_expired(cx) {
493            return Poll::Ready(Err(std::io::Error::new(
494                std::io::ErrorKind::TimedOut,
495                "connection idle timeout",
496            )));
497        }
498        let before = buf.filled().len();
499        match Pin::new(&mut this.inner).poll_read(cx, buf) {
500            Poll::Ready(Ok(())) => {
501                if buf.filled().len() != before {
502                    this.touch();
503                }
504                Poll::Ready(Ok(()))
505            }
506            other => other,
507        }
508    }
509}
510
511impl<S: AsyncWrite + Unpin> AsyncWrite for IdleTimeoutStream<S> {
512    fn poll_write(
513        self: Pin<&mut Self>,
514        cx: &mut Context<'_>,
515        buf: &[u8],
516    ) -> Poll<std::io::Result<usize>> {
517        let this = self.get_mut();
518        match Pin::new(&mut this.inner).poll_write(cx, buf) {
519            Poll::Ready(Ok(written)) => {
520                this.touch();
521                Poll::Ready(Ok(written))
522            }
523            other => other,
524        }
525    }
526
527    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
528        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
529    }
530
531    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
532        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
533    }
534}
535
536async fn drain_with_timeout<F, T>(shutdown: F, timeout: T)
537where
538    F: Future<Output = ()>,
539    T: Future<Output = ()>,
540{
541    tokio::select! {
542        _ = shutdown => {}
543        _ = timeout => {}
544    }
545}
546
547async fn shutdown_signal_with<I, T>(interrupt: I, terminate: T)
548where
549    I: Future<Output = ()>,
550    T: Future<Output = ()>,
551{
552    tokio::select! {
553        _ = interrupt => {}
554        _ = terminate => {}
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::App;
    use crate::extract::RequestContext;
    use crate::response::Response as TorkResponse;
    use crate::router::{BoxFuture, HandlerFn, Route, Router};
    use crate::{json_response, Method, StatusCode};

    use std::future;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::sync::oneshot;

    /// Builds an application core with no routes registered.
    fn empty_app() -> Arc<AppInner> {
        Arc::new(App::new().build().unwrap())
    }

    /// End-to-end: a real TCP client gets the JSON body of a registered route.
    #[tokio::test]
    async fn serves_a_request_over_tcp() {
        let pong: HandlerFn = Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, crate::Result<TorkResponse>> {
                Box::pin(async {
                    let body = serde_json::json!({ "pong": true });
                    Ok(json_response(StatusCode::OK, &body))
                })
            },
        );
        let routes = Router::new().route(Route::new(Method::GET, "/ping", pong));
        let app = Arc::new(App::new().include_router(routes).build().unwrap());

        // Port 0 lets the OS pick a free port; read it back for the client.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (stop_tx, stop_rx) = oneshot::channel::<()>();
        let until_stopped = async move {
            let _ = stop_rx.await;
        };
        let server = tokio::spawn(run_with_shutdown(app, listener, until_stopped));

        // `Connection: close` makes the server end the stream after one
        // response, so `read_to_string` below terminates.
        let mut client = TcpStream::connect(addr).await.unwrap();
        client
            .write_all(b"GET /ping HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .unwrap();

        let mut response = String::new();
        client.read_to_string(&mut response).await.unwrap();

        assert!(
            response.contains("200 OK"),
            "unexpected response: {response}"
        );
        assert!(
            response.contains("\"pong\":true"),
            "unexpected body: {response}"
        );

        let _ = stop_tx.send(());
        let _ = server.await;
    }

    /// Covers the small helpers: a failed accept, a successful hand-off, and
    /// both winners of each two-way race.
    #[tokio::test]
    async fn helper_paths_cover_accept_errors_shutdown_and_signals() {
        let builder = auto::Builder::new(TokioExecutor::new());
        let app = empty_app();
        let graceful = GracefulShutdown::new();

        // An accept error is reported as "not served".
        let failed = handle_accepted_connection::<tokio::io::DuplexStream>(
            app.clone(),
            &builder,
            &graceful,
            Err(std::io::Error::other("accept failed")),
        )
        .await;
        assert!(!failed);

        // An accepted stream is handed off; the other duplex end is kept
        // alive until the end of the test.
        let (stream, _peer) = tokio::io::duplex(16);
        let served = handle_accepted_connection(
            app,
            &builder,
            &graceful,
            Ok((stream, Some("127.0.0.1:0".parse().unwrap()))),
        )
        .await;
        assert!(served);

        drain_with_timeout(future::ready(()), future::pending::<()>()).await;
        drain_with_timeout(future::pending::<()>(), future::ready(())).await;

        shutdown_signal_with(future::ready(()), future::pending::<()>()).await;
        shutdown_signal_with(future::pending::<()>(), future::ready(())).await;
    }

    /// The service type is `Clone` (derived); this fails to compile otherwise.
    #[tokio::test]
    async fn tork_service_new_returns_cloneable_service() {
        let service = TorkService::new(empty_app());
        let _cloned = service.clone();
    }

    #[tokio::test]
    async fn run_with_shutdown_breaks_when_shutdown_resolves_first() {
        // Build a minimal app, bind to an ephemeral port, and run the loop
        // with a shutdown future that is already complete. No client ever
        // connects, so the loop leaves through the
        // `_ = shutdown.as_mut() => break` branch.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        run_with_shutdown(empty_app(), listener, future::ready(())).await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn reuse_port_allows_two_listeners_on_the_same_port() {
        // With SO_REUSEPORT, a second listener can bind the port the first holds.
        let first = bind_tcp_listener("127.0.0.1:0", true).await.unwrap();
        let taken = first.local_addr().unwrap();
        let second = bind_tcp_listener(&taken.to_string(), true).await.unwrap();
        assert_eq!(second.local_addr().unwrap().port(), taken.port());
    }
}