volo_http/server/
mod.rs

1//! Server implementation
2//!
3//! See [`Server`] for more details.
4
5use std::{
6    cell::RefCell,
7    convert::Infallible,
8    error::Error,
9    sync::{
10        Arc,
11        atomic::{AtomicUsize, Ordering},
12    },
13    time::Duration,
14};
15
16use futures::future::BoxFuture;
17use hyper_util::{
18    rt::{TokioExecutor, TokioIo},
19    server::conn::auto,
20};
21use metainfo::{METAINFO, MetaInfo};
22use motore::{
23    BoxError,
24    layer::{Identity, Layer, Stack},
25    service::Service,
26};
27use parking_lot::RwLock;
28use scopeguard::defer;
29use tokio::sync::Notify;
30use tracing::Instrument;
31#[cfg(feature = "__tls")]
32use volo::net::{conn::ConnStream, tls::ServerTlsConfig};
33use volo::{
34    context::Context,
35    net::{Address, MakeIncoming, conn::Conn, incoming::Incoming},
36};
37
38use self::span_provider::{DefaultProvider, SpanProvider};
39use crate::{
40    body::Body,
41    context::{ServerContext, server::Config},
42    request::Request,
43    response::Response,
44};
45
46pub mod extract;
47mod handler;
48pub mod layer;
49pub mod middleware;
50pub mod panic_handler;
51pub mod param;
52pub mod protocol;
53pub mod response;
54pub mod route;
55pub mod span_provider;
56#[cfg(test)]
57pub mod test_helpers;
58pub mod utils;
59
60pub use self::{
61    response::{IntoResponse, Redirect},
62    route::Router,
63};
64
65#[doc(hidden)]
66pub mod prelude {
67    #[cfg(feature = "__tls")]
68    pub use volo::net::tls::ServerTlsConfig;
69
70    pub use super::{Server, param::PathParams, route::Router};
71}
72
73/// High level HTTP server.
74///
75/// # Examples
76///
77/// ```no_run
78/// use std::net::SocketAddr;
79///
80/// use volo::net::Address;
81/// use volo_http::server::{
82///     Server,
83///     route::{Router, get},
84/// };
85///
86/// async fn index() -> &'static str {
87///     "Hello, World!"
88/// }
89///
90/// let app = Router::new().route("/", get(index));
91/// let addr = "[::]:8080".parse::<SocketAddr>().unwrap();
92/// let addr = Address::from(addr);
93///
94/// # tokio_test::block_on(async {
95/// Server::new(app).run(addr).await.unwrap();
96/// # })
97/// ```
98pub struct Server<S, L = Identity, SP = DefaultProvider> {
99    service: S,
100    layer: L,
101    server: auto::Builder<TokioExecutor>,
102    config: Config,
103    shutdown_hooks: Vec<Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>>,
104    span_provider: SP,
105    #[cfg(feature = "__tls")]
106    tls_config: Option<ServerTlsConfig>,
107}
108
109impl<S> Server<S, Identity, DefaultProvider> {
110    /// Create a new server.
111    pub fn new(service: S) -> Self {
112        Self {
113            service,
114            layer: Identity::new(),
115            server: auto::Builder::new(TokioExecutor::new()),
116            config: Config::default(),
117            shutdown_hooks: Vec::new(),
118            span_provider: DefaultProvider,
119            #[cfg(feature = "__tls")]
120            tls_config: None,
121        }
122    }
123}
124
125impl<S, L, SP> Server<S, L, SP> {
126    /// Enable TLS with the specified configuration.
127    ///
128    /// If not set, the server will not use TLS.
129    #[cfg(feature = "__tls")]
130    #[cfg_attr(docsrs, doc(cfg(any(feature = "rustls", feature = "native-tls"))))]
131    pub fn tls_config(mut self, config: impl Into<ServerTlsConfig>) -> Self {
132        self.tls_config = Some(config.into());
133        self.config.set_tls(true);
134        self
135    }
136
137    /// Register shutdown hook.
138    ///
139    /// Hook functions will be called just before volo's own gracefull existing code starts,
140    /// in reverse order of registration.
141    pub fn register_shutdown_hook(
142        mut self,
143        hook: impl FnOnce() -> BoxFuture<'static, ()> + 'static + Send,
144    ) -> Self {
145        self.shutdown_hooks.push(Box::new(hook));
146        self
147    }
148
149    /// Add a new inner layer to the server.
150    ///
151    /// The layer's [`Service`] should be `Send + Sync + Clone + 'static`.
152    ///
153    /// # Order
154    ///
155    /// Assume we already have two layers: foo and bar. We want to add a new layer baz.
156    ///
157    /// The current order is: foo -> bar (the request will come to foo first, and then bar).
158    ///
159    /// After we call `.layer(baz)`, we will get: foo -> bar -> baz.
160    pub fn layer<Inner>(self, layer: Inner) -> Server<S, Stack<Inner, L>, SP> {
161        Server {
162            service: self.service,
163            layer: Stack::new(layer, self.layer),
164            server: self.server,
165            config: self.config,
166            shutdown_hooks: self.shutdown_hooks,
167            span_provider: self.span_provider,
168            #[cfg(feature = "__tls")]
169            tls_config: self.tls_config,
170        }
171    }
172
173    /// Add a new front layer to the server.
174    ///
175    /// The layer's [`Service`] should be `Send + Sync + Clone + 'static`.
176    ///
177    /// # Order
178    ///
179    /// Assume we already have two layers: foo and bar. We want to add a new layer baz.
180    ///
181    /// The current order is: foo -> bar (the request will come to foo first, and then bar).
182    ///
183    /// After we call `.layer_front(baz)`, we will get: baz -> foo -> bar.
184    pub fn layer_front<Front>(self, layer: Front) -> Server<S, Stack<L, Front>, SP> {
185        Server {
186            service: self.service,
187            layer: Stack::new(self.layer, layer),
188            server: self.server,
189            config: self.config,
190            shutdown_hooks: self.shutdown_hooks,
191            span_provider: self.span_provider,
192            #[cfg(feature = "__tls")]
193            tls_config: self.tls_config,
194        }
195    }
196
197    /// Set a [`SpanProvider`] to the server.
198    ///
199    /// Server will enter the [`Span`] that created by [`SpanProvider::on_serve`] when starting to
200    /// serve a request, and call [`SpanProvider::leave_serve`] when leaving the serve function.
201    ///
202    /// [`Span`]: tracing::Span
203    pub fn span_provider<P>(self, span_provider: P) -> Server<S, L, P> {
204        Server {
205            service: self.service,
206            layer: self.layer,
207            server: self.server,
208            config: self.config,
209            shutdown_hooks: self.shutdown_hooks,
210            span_provider,
211            #[cfg(feature = "__tls")]
212            tls_config: self.tls_config,
213        }
214    }
215
216    /// This is unstable now and may be changed in the future.
217    #[doc(hidden)]
218    pub fn config(&self) -> &Config {
219        &self.config
220    }
221
222    /// This is unstable now and may be changed in the future.
223    #[doc(hidden)]
224    pub fn config_mut(&mut self) -> &mut Config {
225        &mut self.config
226    }
227
228    /// Set the maximum number of headers.
229    ///
230    /// When a request is received, the parser will reserve a buffer to store headers for optimal
231    /// performance.
232    ///
233    /// If server receives more headers than the buffer size, it responds to the client with
234    /// "431 Request Header Fields Too Large".
235    ///
236    /// Note that headers is allocated on the stack by default, which has higher performance. After
237    /// setting this value, headers will be allocated in heap memory, that is, heap memory
238    /// allocation will occur for each request, and there will be a performance drop of about 5%.
239    ///
240    /// Default is 100.
241    #[deprecated(
242        since = "0.4.0",
243        note = "`set_max_headers` has been removed into `http1_config`"
244    )]
245    #[cfg(feature = "http1")]
246    pub fn set_max_headers(&mut self, max_headers: usize) -> &mut Self {
247        self.server.http1().max_headers(max_headers);
248        self
249    }
250
251    /// Get configuration for http1 part.
252    #[cfg(feature = "http1")]
253    pub fn http1_config(&mut self) -> self::protocol::Http1Config<'_> {
254        self::protocol::Http1Config {
255            inner: self.server.http1(),
256        }
257    }
258
259    /// Get configuration for http2 part.
260    #[cfg(feature = "http2")]
261    pub fn http2_config(&mut self) -> self::protocol::Http2Config<'_> {
262        self::protocol::Http2Config {
263            inner: self.server.http2(),
264        }
265    }
266
267    /// Make server accept only HTTP/1.
268    #[cfg(feature = "http1")]
269    pub fn http1_only(self) -> Self {
270        Self {
271            service: self.service,
272            layer: self.layer,
273            server: self.server.http1_only(),
274            config: self.config,
275            shutdown_hooks: self.shutdown_hooks,
276            span_provider: self.span_provider,
277            #[cfg(feature = "__tls")]
278            tls_config: self.tls_config,
279        }
280    }
281
282    /// Make server accept only HTTP/2.
283    #[cfg(feature = "http2")]
284    pub fn http2_only(self) -> Self {
285        Self {
286            service: self.service,
287            layer: self.layer,
288            server: self.server.http2_only(),
289            config: self.config,
290            shutdown_hooks: self.shutdown_hooks,
291            span_provider: self.span_provider,
292            #[cfg(feature = "__tls")]
293            tls_config: self.tls_config,
294        }
295    }
296
297    /// The main entry point for the server.
298    pub async fn run<MI, B>(self, mk_incoming: MI) -> Result<(), BoxError>
299    where
300        S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
301        S::Response: IntoResponse,
302        S::Error: IntoResponse,
303        L: Layer<S> + Send + Sync + 'static,
304        L::Service: Service<ServerContext, Request, Error = Infallible> + Send + Sync + 'static,
305        <L::Service as Service<ServerContext, Request>>::Response: IntoResponse,
306        SP: SpanProvider + Clone + Send + Sync + Unpin + 'static,
307        MI: MakeIncoming,
308    {
309        let server = Arc::new(self.server);
310        let service = Arc::new(self.layer.layer(self.service));
311        let incoming = mk_incoming.make_incoming().await?;
312        tracing::info!("[Volo-HTTP] server start at: {:?}", incoming);
313
314        // count connections, used for graceful shutdown
315        let conn_cnt = Arc::new(AtomicUsize::new(0));
316        // flag for stopping serve
317        let exit_flag = Arc::new(parking_lot::RwLock::new(false));
318        // notifier for stopping all inflight connections
319        let exit_notify = Arc::new(Notify::const_new());
320
321        let handler = tokio::spawn(serve(
322            server,
323            incoming,
324            service,
325            self.config,
326            exit_flag.clone(),
327            conn_cnt.clone(),
328            exit_notify.clone(),
329            self.span_provider,
330            #[cfg(feature = "__tls")]
331            self.tls_config,
332        ));
333
334        #[cfg(target_family = "unix")]
335        {
336            // graceful shutdown
337            let mut sigint =
338                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
339            let mut sighup =
340                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())?;
341            let mut sigterm =
342                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
343
344            // graceful shutdown handler
345            tokio::select! {
346                _ = sigint.recv() => {}
347                _ = sighup.recv() => {}
348                _ = sigterm.recv() => {}
349                _ = handler => {},
350            }
351        }
352
353        // graceful shutdown handler for windows
354        #[cfg(target_family = "windows")]
355        tokio::select! {
356            _ = tokio::signal::ctrl_c() => {}
357            _ = handler => {},
358        }
359
360        if !self.shutdown_hooks.is_empty() {
361            tracing::info!("[Volo-HTTP] call shutdown hooks");
362
363            for hook in self.shutdown_hooks {
364                (hook)().await;
365            }
366        }
367
368        // received signal, graceful shutdown now
369        tracing::info!("[Volo-HTTP] received signal, gracefully exiting now");
370        *exit_flag.write() = true;
371
372        // Now we won't accept new connections.
373        // And we want to send crrst reply to the peers in the short future.
374        if conn_cnt.load(Ordering::Relaxed) != 0 {
375            tokio::time::sleep(Duration::from_secs(2)).await;
376        }
377        exit_notify.notify_waiters();
378
379        // wait for all connections to be closed
380        for _ in 0..28 {
381            if conn_cnt.load(Ordering::Relaxed) == 0 {
382                break;
383            }
384            tracing::trace!(
385                "[Volo-HTTP] gracefully exiting, remaining connection count: {}",
386                conn_cnt.load(Ordering::Relaxed),
387            );
388            tokio::time::sleep(Duration::from_secs(1)).await;
389        }
390
391        Ok(())
392    }
393}
394
395#[allow(clippy::too_many_arguments)]
396async fn serve<I, S, SP>(
397    server: Arc<auto::Builder<TokioExecutor>>,
398    mut incoming: I,
399    service: S,
400    config: Config,
401    exit_flag: Arc<RwLock<bool>>,
402    conn_cnt: Arc<AtomicUsize>,
403    exit_notify: Arc<Notify>,
404    span_provider: SP,
405    #[cfg(feature = "__tls")] tls_config: Option<ServerTlsConfig>,
406) where
407    I: Incoming,
408    S: Service<ServerContext, Request> + Clone + Unpin + Send + Sync + 'static,
409    S::Response: IntoResponse,
410    S::Error: IntoResponse,
411    SP: SpanProvider + Clone + Send + Sync + Unpin + 'static,
412{
413    loop {
414        if *exit_flag.read() {
415            break;
416        }
417
418        let conn = match incoming.accept().await {
419            Ok(Some(conn)) => conn,
420            _ => continue,
421        };
422        #[cfg(feature = "__tls")]
423        let conn = {
424            let Conn { stream, info } = conn;
425            match (stream, &tls_config) {
426                (ConnStream::Tcp(stream), Some(tls_config)) => {
427                    let stream = match tls_config.acceptor.accept(stream).await {
428                        Ok(conn) => conn,
429                        Err(err) => {
430                            tracing::trace!("[Volo-HTTP] tls handshake error: {err:?}");
431                            continue;
432                        }
433                    };
434                    Conn { stream, info }
435                }
436                (stream, _) => Conn { stream, info },
437            }
438        };
439
440        let peer = match conn.info.peer_addr {
441            Some(ref peer) => {
442                tracing::trace!("accept connection from: {peer:?}");
443                peer.clone()
444            }
445            None => {
446                tracing::info!("no peer address found from server connection");
447                continue;
448            }
449        };
450
451        let hyper_service = HyperService {
452            inner: service.clone(),
453            peer,
454            config: config.clone(),
455            span_provider: span_provider.clone(),
456        };
457
458        tokio::spawn(serve_conn(
459            server.clone(),
460            conn,
461            hyper_service,
462            conn_cnt.clone(),
463            exit_notify.clone(),
464        ));
465    }
466}
467
468async fn serve_conn<S>(
469    server: Arc<auto::Builder<TokioExecutor>>,
470    conn: Conn,
471    service: S,
472    conn_cnt: Arc<AtomicUsize>,
473    exit_notify: Arc<Notify>,
474) where
475    S: hyper::service::Service<HyperRequest, Response = Response> + Unpin,
476    S::Future: Send + 'static,
477    S::Error: Error + Send + Sync + 'static,
478{
479    conn_cnt.fetch_add(1, Ordering::Relaxed);
480    defer! {
481        conn_cnt.fetch_sub(1, Ordering::Relaxed);
482    }
483
484    let notified = exit_notify.notified();
485    tokio::pin!(notified);
486
487    let http_conn = server.serve_connection_with_upgrades(TokioIo::new(conn), service);
488    futures::pin_mut!(http_conn);
489
490    tokio::select! {
491        _ = &mut notified => {
492            tracing::trace!("[Volo-HTTP] closing a pending connection");
493            // Graceful shutdown.
494            http_conn.as_mut().graceful_shutdown();
495            // Continue to poll this connection until shutdown can finish.
496            let result = http_conn.as_mut().await;
497            if let Err(err) = result {
498                tracing::debug!("[Volo-HTTP] connection error: {:?}", err);
499            }
500        }
501        result = http_conn.as_mut() => {
502            if let Err(err) = result {
503                tracing::debug!("[Volo-HTTP] connection error: {:?}", err);
504            }
505        },
506    }
507}
508
509#[derive(Clone)]
510struct HyperService<S, SP> {
511    inner: S,
512    peer: Address,
513    config: Config,
514    span_provider: SP,
515}
516
517type HyperRequest = http::request::Request<hyper::body::Incoming>;
518
519impl<S, SP> hyper::service::Service<HyperRequest> for HyperService<S, SP>
520where
521    S: Service<ServerContext, Request> + Clone + Send + Sync + 'static,
522    S::Response: IntoResponse,
523    S::Error: IntoResponse,
524    SP: SpanProvider + Clone + Send + Sync + 'static,
525{
526    type Response = Response;
527    type Error = Infallible;
528    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
529
530    fn call(&self, req: HyperRequest) -> Self::Future {
531        let service = self.clone();
532        Box::pin(
533            METAINFO.scope(RefCell::new(MetaInfo::default()), async move {
534                let mut cx = ServerContext::new(service.peer);
535                cx.rpc_info_mut().set_config(service.config);
536                let span = service.span_provider.on_serve(&cx);
537                let resp = service
538                    .inner
539                    .call(&mut cx, req.map(Body::from_incoming))
540                    .instrument(span)
541                    .await
542                    .into_response();
543                service.span_provider.leave_serve(&cx);
544                Ok(resp)
545            }),
546        )
547    }
548}