salvo_core/
server.rs

//! HTTP server that accepts connections from an [`Acceptor`] and dispatches requests to a [`Service`].
2use std::fmt::{self, Debug, Formatter};
3use std::io::Result as IoResult;
4use std::sync::Arc;
5#[cfg(feature = "server-handle")]
6use std::sync::atomic::{AtomicUsize, Ordering};
7
// Guard: the `server` feature is useless without at least one protocol implementation,
// so fail the build early with an actionable message instead of obscure missing-item errors.
#[cfg(not(any(feature = "http1", feature = "http2", feature = "quinn")))]
compile_error!(
    "You have enabled `server` feature, it requires at least one of the following features: http1, http2, quinn."
);
12
13#[cfg(feature = "http1")]
14use hyper::server::conn::http1;
15#[cfg(feature = "http2")]
16use hyper::server::conn::http2;
17#[cfg(feature = "server-handle")]
18use tokio::{
19    sync::{
20        Notify,
21        mpsc::{UnboundedReceiver, UnboundedSender},
22    },
23    time::Duration,
24};
25#[cfg(feature = "server-handle")]
26use tokio_util::sync::CancellationToken;
27
28use crate::Service;
29#[cfg(feature = "quinn")]
30use crate::conn::quinn;
31use crate::conn::{Accepted, Coupler, Acceptor, Holding, HttpBuilder};
32use crate::fuse::{ArcFuseFactory, FuseFactory};
33use crate::http::{HeaderValue,  Version};
34
35cfg_feature! {
36    #![feature ="server-handle"]
37    /// Server handle is used to stop server.
38    #[derive(Clone)]
39    pub struct ServerHandle {
40        tx_cmd: UnboundedSender<ServerCommand>,
41    }
42    impl Debug for ServerHandle {
43        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
44            f.debug_struct("ServerHandle").finish()
45        }
46    }
47}
48
#[cfg(feature = "server-handle")]
impl ServerHandle {
    /// Force stop server.
    ///
    /// Calling this function stops the server immediately: it stops accepting new
    /// connections and does not wait for open connections to finish.
    ///
    /// If the server is no longer running, the command is silently discarded.
    pub fn stop_forcible(&self) {
        // `send` only fails once the server has dropped its receiver (it is no longer
        // running), so there is nothing useful to report to the caller.
        let _ = self.tx_cmd.send(ServerCommand::StopForcible);
    }

    /// Graceful stop server.
    ///
    /// Calling this function makes the server stop accepting new connections and then
    /// wait for the open ones to close, allowing them to finish processing any ongoing
    /// requests before the server terminates.
    ///
    /// You can specify a timeout to force stop the server: once `timeout` elapses, the
    /// remaining connections are dropped. If `timeout` is `None`, it waits until all
    /// connections are closed.
    ///
    /// If the server is no longer running, the command is silently discarded.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use salvo_core::prelude::*;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let acceptor = TcpListener::new("127.0.0.1:8698").bind().await;
    ///     let server = Server::new(acceptor);
    ///     let handle = server.handle();
    ///
    ///     // Gracefully shut down the server after 60 seconds.
    ///     tokio::spawn(async move {
    ///         tokio::time::sleep(std::time::Duration::from_secs(60)).await;
    ///         handle.stop_graceful(None);
    ///     });
    ///     server.serve(Router::new()).await;
    /// }
    /// ```
    pub fn stop_graceful(&self, timeout: impl Into<Option<Duration>>) {
        // See `stop_forcible` for why a send error is ignored.
        let _ = self
            .tx_cmd
            .send(ServerCommand::StopGraceful(timeout.into()));
    }
}
96
/// Commands delivered from a `ServerHandle` (or the `Server` itself) to the serving loop.
#[cfg(feature = "server-handle")]
enum ServerCommand {
    /// Stop accepting and drop all open connections right away.
    StopForcible,
    /// Stop accepting, then let open connections finish. `Some(timeout)` escalates to a
    /// forcible stop once the timeout elapses; `None` sets no deadline.
    StopGraceful(Option<Duration>),
}
102
103/// HTTP Server.
104///
105/// A `Server` is created to listen on a port, parse HTTP requests, and hand them off to a [`Service`].
106pub struct Server<A> {
107    acceptor: A,
108    builder: HttpBuilder,
109    fuse_factory: Option<ArcFuseFactory>,
110    #[cfg(feature = "server-handle")]
111    tx_cmd: UnboundedSender<ServerCommand>,
112    #[cfg(feature = "server-handle")]
113    rx_cmd: UnboundedReceiver<ServerCommand>,
114}
115
116impl<A> Debug for Server<A> {
117    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
118        f.debug_struct("Server").finish()
119    }
120}
121
122impl<A: Acceptor + Send> Server<A> {
123    /// Create new `Server` with [`Acceptor`].
124    ///
125    /// # Example
126    ///
127    /// ```no_run
128    /// use salvo_core::prelude::*;
129    ///
130    /// #[tokio::main]
131    /// async fn main() {
132    ///     let acceptor = TcpListener::new("127.0.0.1:8698").bind().await;
133    ///     Server::new(acceptor);
134    /// }
135    /// ```
136    pub fn new(acceptor: A) -> Self {
137        Self::with_http_builder(acceptor, HttpBuilder::new())
138    }
139
140    /// Create new `Server` with [`Acceptor`] and [`HttpBuilder`].
141    pub fn with_http_builder(acceptor: A, builder: HttpBuilder) -> Self {
142        #[cfg(feature = "server-handle")]
143        let (tx_cmd, rx_cmd) = tokio::sync::mpsc::unbounded_channel();
144        Self {
145            acceptor,
146            builder,
147            fuse_factory: None,
148            #[cfg(feature = "server-handle")]
149            tx_cmd,
150            #[cfg(feature = "server-handle")]
151            rx_cmd,
152        }
153    }
154
155    /// Set the fuse factory.
156    #[must_use]
157    pub fn fuse_factory<F>(mut self, factory: F) -> Self
158    where
159        F: FuseFactory + Send + Sync + 'static,
160    {
161        self.fuse_factory = Some(Arc::new(factory));
162        self
163    }
164
165    cfg_feature! {
166        #![feature = "server-handle"]
167        /// Get a [`ServerHandle`] to stop server.
168        pub fn handle(&self) -> ServerHandle {
169            ServerHandle {
170                tx_cmd: self.tx_cmd.clone(),
171            }
172        }
173
174        /// Force stop server.
175        ///
176        /// Call this function will stop server immediately.
177        pub fn stop_forcible(&self) {
178            let _ = self.tx_cmd.send(ServerCommand::StopForcible);
179        }
180
181        /// Graceful stop server.
182        ///
183        /// Call this function will stop server after all connections are closed.
184        /// You can specify a timeout to force stop server.
185        /// If `timeout` is `None`, it will wait until all connections are closed.
186        pub fn stop_graceful(&self, timeout: impl Into<Option<Duration>>) {
187            let _ = self.tx_cmd.send(ServerCommand::StopGraceful(timeout.into()));
188        }
189    }
190
191    /// Get holding information of this server.
192    #[inline]
193    pub fn holdings(&self) -> &[Holding] {
194        self.acceptor.holdings()
195    }
196
197    cfg_feature! {
198        #![feature = "http1"]
199        /// Use this function to set http1 protocol.
200        pub fn http1_mut(&mut self) -> &mut http1::Builder {
201            &mut self.builder.http1
202        }
203    }
204
205    cfg_feature! {
206        #![feature = "http2"]
207        /// Use this function to set http2 protocol.
208        pub fn http2_mut(&mut self) -> &mut http2::Builder<crate::rt::tokio::TokioExecutor> {
209            &mut self.builder.http2
210        }
211    }
212
213    cfg_feature! {
214        #![feature = "quinn"]
215        /// Use this function to set http3 protocol.
216        pub fn quinn_mut(&mut self) -> &mut quinn::Builder {
217            &mut self.builder.quinn
218        }
219    }
220
221    /// Serve a [`Service`].
222    ///
223    /// # Example
224    ///
225    /// ```no_run
226    /// use salvo_core::prelude::*;
227    /// #[handler]
228    /// async fn hello() -> &'static str {
229    ///     "Hello World"
230    /// }
231    ///
232    /// #[tokio::main]
233    /// async fn main() {
234    ///     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
235    ///     let router = Router::new().get(hello);
236    ///     Server::new(acceptor).serve(router).await;
237    /// }
238    /// ```
239    #[inline]
240    pub async fn serve<S>(self, service: S)
241    where
242        S: Into<Service> + Send,
243    {
244        self.try_serve(service)
245            .await
246            .expect("failed to call `Server::serve`");
247    }
248
249    /// Try to serve a [`Service`].
250    #[cfg(feature = "server-handle")]
251    #[allow(clippy::manual_async_fn)] //Fix: https://github.com/salvo-rs/salvo/issues/902
252    pub fn try_serve<S>(self, service: S) -> impl Future<Output = IoResult<()>> + Send
253    where
254        S: Into<Service> + Send,
255    {
256        async {
257            let Self {
258                mut acceptor,
259                builder,
260                fuse_factory,
261                mut rx_cmd,
262                ..
263            } = self;
264            let alive_connections = Arc::new(AtomicUsize::new(0));
265            let notify = Arc::new(Notify::new());
266            let force_stop_token = CancellationToken::new();
267            let graceful_stop_token = CancellationToken::new();
268
269            let mut alt_svc_h3 = None;
270            for holding in acceptor.holdings() {
271                tracing::info!("listening {}", holding);
272                if holding.http_versions.contains(&Version::HTTP_3) {
273                    if let Some(addr) = holding.local_addr.clone().into_std() {
274                        let port = addr.port();
275                        alt_svc_h3 = Some(
276                            format!(r#"h3=":{port}"; ma=2592000,h3-29=":{port}"; ma=2592000"#)
277                                .parse::<HeaderValue>()
278                                .expect("Parse alt-svc header should not failed."),
279                        );
280                    }
281                }
282            }
283
284            let service: Arc<Service> = Arc::new(service.into());
285            let builder = Arc::new(builder);
286            loop {
287                tokio::select! {
288                    accepted = acceptor.accept(fuse_factory.clone()) => {
289                        match accepted {
290                            Ok(Accepted { coupler, stream, fusewire, local_addr, remote_addr, http_scheme, ..}) => {
291                                alive_connections.fetch_add(1, Ordering::Release);
292
293                                let service = service.clone();
294                                let alive_connections = alive_connections.clone();
295                                let notify = notify.clone();
296                                let handler = service.hyper_handler(local_addr, remote_addr, http_scheme, fusewire, alt_svc_h3.clone());
297                                let builder = builder.clone();
298
299                                let force_stop_token = force_stop_token.clone();
300                                let graceful_stop_token = graceful_stop_token.clone();
301
302                                tokio::spawn(async move {
303                                    let conn = coupler.couple(stream, handler, builder, Some(graceful_stop_token.clone()));
304                                    tokio::select! {
305                                        _ = conn => {
306                                        },
307                                        _ = force_stop_token.cancelled() => {
308                                        }
309                                    }
310
311                                    if alive_connections.fetch_sub(1, Ordering::Acquire) == 1 {
312                                        // notify only if shutdown is initiated, to prevent notification when server is active.
313                                        // It's a valid state to have 0 alive connections when server is not shutting down.
314                                        if graceful_stop_token.is_cancelled() {
315                                            notify.notify_one();
316                                        }
317                                    }
318                                });
319                            },
320                            Err(e) => {
321                                tracing::error!(error = ?e, "accept connection failed");
322                            }
323                        }
324                    }
325                    Some(cmd) = rx_cmd.recv() => {
326                        match cmd {
327                            ServerCommand::StopGraceful(timeout) => {
328                                let graceful_stop_token = graceful_stop_token.clone();
329                                graceful_stop_token.cancel();
330                                if let Some(timeout) = timeout {
331                                    tracing::info!(
332                                        timeout_in_seconds = timeout.as_secs_f32(),
333                                        "initiate graceful stop server",
334                                    );
335
336                                    let force_stop_token = force_stop_token.clone();
337                                    tokio::spawn(async move {
338                                        tokio::time::sleep(timeout).await;
339                                        force_stop_token.cancel();
340                                    });
341                                } else {
342                                    tracing::info!("initiate graceful stop server");
343                                }
344                            },
345                            ServerCommand::StopForcible => {
346                                tracing::info!("force stop server");
347                                force_stop_token.cancel();
348                            },
349                        }
350                        break;
351                    },
352                }
353            }
354
355            if !force_stop_token.is_cancelled() && alive_connections.load(Ordering::Acquire) > 0 {
356                tracing::info!(
357                    "wait for {} connections to close.",
358                    alive_connections.load(Ordering::Acquire)
359                );
360                notify.notified().await;
361            }
362
363            tracing::info!("server stopped");
364            Ok(())
365        }
366    }
367    /// Try to serve a [`Service`].
368    #[cfg(not(feature = "server-handle"))]
369    pub async fn try_serve<S>(self, service: S) -> IoResult<()>
370    where
371        S: Into<Service> + Send,
372    {
373        let Self {
374            mut acceptor,
375            builder,
376            fuse_factory,
377            ..
378        } = self;
379        let mut alt_svc_h3 = None;
380        for holding in acceptor.holdings() {
381            tracing::info!("listening {}", holding);
382            if holding.http_versions.contains(&Version::HTTP_3) {
383                if let Some(addr) = holding.local_addr.clone().into_std() {
384                    let port = addr.port();
385                    alt_svc_h3 = Some(
386                        format!(r#"h3=":{port}"; ma=2592000,h3-29=":{port}"; ma=2592000"#)
387                            .parse::<HeaderValue>()
388                            .expect("Parse alt-svc header should not failed."),
389                    );
390                }
391            }
392        }
393
394        let service: Arc<Service> = Arc::new(service.into());
395        let builder = Arc::new(builder);
396        loop {
397            match acceptor.accept(fuse_factory.clone()).await {
398                Ok(Accepted {
399                    coupler,
400                    stream,
401                    fusewire,
402                    local_addr,
403                    remote_addr,
404                    http_scheme,
405                    ..
406                }) => {
407                    let service = service.clone();
408                    let handler = service.hyper_handler(
409                        local_addr,
410                        remote_addr,
411                        http_scheme,
412                        fusewire,
413                        alt_svc_h3.clone(),
414                    );
415                    let builder = builder.clone();
416
417                    tokio::spawn(async move {
418                        let _ = coupler.couple(stream, handler, builder, None).await;
419                    });
420                }
421                Err(e) => {
422                    tracing::error!(error = ?e, "accept connection failed");
423                }
424            }
425        }
426    }
427}
428
#[cfg(test)]
mod tests {
    use serde::Serialize;

    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    #[tokio::test]
    async fn test_server() {
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("Hello World")
        }
        #[handler]
        async fn json(res: &mut Response) {
            #[derive(Serialize, Debug)]
            struct User {
                name: String,
            }
            res.render(Json(User {
                name: "jobs".into(),
            }));
        }
        let router = Router::new()
            .get(hello)
            .push(Router::with_path("json").get(json));
        let service = Service::new(router);

        let base_url = "http://127.0.0.1:8698";
        let result = TestClient::get(base_url)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(result, "Hello World");

        let result = TestClient::get(format!("{base_url}/json"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(result, r#"{"name":"jobs"}"#);

        let result = TestClient::get(format!("{base_url}/not_exist"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("Not Found"));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "application/json", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains(r#""code":404"#));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "text/plain", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("code: 404"));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "application/xml", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("<code>404</code>"));
    }

    #[cfg(feature = "server-handle")]
    #[tokio::test]
    async fn test_server_handle_stop() {
        use std::time::Duration;
        use tokio::time::timeout;

        // Test forcible stop.
        // Port 0 lets the OS pick a free port, so the test cannot collide with other
        // processes or concurrent test runs.
        let acceptor = crate::conn::TcpListener::new("127.0.0.1:0").bind().await;
        let server = Server::new(acceptor);
        let handle = server.handle();
        let server_task = tokio::spawn(server.try_serve(Router::new()));

        // Give server a moment to start
        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.stop_forcible();

        let result = timeout(Duration::from_secs(1), server_task).await;
        assert!(result.is_ok(), "Server should stop forcibly within 1 second.");
        let server_result = result.unwrap();
        assert!(server_result.is_ok(), "Server task should not panic.");
        assert!(server_result.unwrap().is_ok(), "try_serve should return Ok.");

        // Test graceful stop
        let acceptor = crate::conn::TcpListener::new("127.0.0.1:0").bind().await;
        let server = Server::new(acceptor);
        let handle = server.handle();
        let server_task = tokio::spawn(server.try_serve(Router::new()));

        // Give server a moment to start
        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.stop_graceful(None);

        let result = timeout(Duration::from_secs(1), server_task).await;
        assert!(result.is_ok(), "Server should stop gracefully within 1 second.");
        let server_result = result.unwrap();
        assert!(server_result.is_ok(), "Server task should not panic.");
        assert!(server_result.unwrap().is_ok(), "try_serve should return Ok.");
    }

    /// A `stop_forcible` issued while a graceful stop is waiting on an in-flight request
    /// must still stop the server promptly.
    #[cfg(feature = "server-handle")]
    #[tokio::test]
    async fn test_server_handle_force_stop_during_graceful_stop() {
        use std::io::Write;
        use std::time::Duration;
        use tokio::time::timeout;

        #[handler]
        async fn slow() -> &'static str {
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            "slow"
        }

        let acceptor = crate::conn::TcpListener::new("127.0.0.1:0").bind().await;
        let server = Server::new(acceptor);
        let addr = server
            .holdings()
            .first()
            .and_then(|holding| holding.local_addr.clone().into_std())
            .expect("tcp listener should expose its bound address");
        let handle = server.handle();
        let server_task = tokio::spawn(server.try_serve(Router::new().get(slow)));

        // Open a connection whose request stays in flight far longer than the test.
        let mut stream = std::net::TcpStream::connect(addr).expect("connect to server");
        stream
            .write_all(b"GET / HTTP/1.1\r\nhost: localhost\r\n\r\n")
            .expect("send request");
        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.stop_graceful(None);
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.stop_forcible();

        let result = timeout(Duration::from_secs(1), server_task).await;
        assert!(
            result.is_ok(),
            "Server should honour stop_forcible while a graceful stop is in progress."
        );
        let server_result = result.unwrap();
        assert!(server_result.is_ok(), "Server task should not panic.");
        assert!(server_result.unwrap().is_ok(), "try_serve should return Ok.");
        drop(stream);
    }

    #[test]
    fn test_regression_209() {
        #[cfg(feature = "acme")]
        let _: &dyn Send = &async {
            let acceptor = TcpListener::new("127.0.0.1:0")
                .acme()
                .add_domain("test.salvo.rs")
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "native-tls")]
        let _: &dyn Send = &async {
            use crate::conn::native_tls::NativeTlsConfig;

            let identity = if cfg!(target_os = "macos") {
                include_bytes!("../certs/identity-legacy.p12").to_vec()
            } else {
                include_bytes!("../certs/identity.p12").to_vec()
            };
            let acceptor = TcpListener::new("127.0.0.1:0")
                .native_tls(NativeTlsConfig::new().pkcs12(identity).password("mypass"))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "openssl")]
        let _: &dyn Send = &async {
            use crate::conn::openssl::{Keycert, OpensslConfig};

            let acceptor = TcpListener::new("127.0.0.1:0")
                .openssl(OpensslConfig::new(
                    Keycert::new()
                        .key_from_path("certs/key.pem")
                        .unwrap()
                        .cert_from_path("certs/cert.pem")
                        .unwrap(),
                ))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "rustls")]
        let _: &dyn Send = &async {
            use crate::conn::rustls::{Keycert, RustlsConfig};

            let acceptor = TcpListener::new("127.0.0.1:0")
                .rustls(RustlsConfig::new(
                    Keycert::new()
                        .key_from_path("certs/key.pem")
                        .unwrap()
                        .cert_from_path("certs/cert.pem")
                        .unwrap(),
                ))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "quinn")]
        let _: &dyn Send = &async {
            use crate::conn::rustls::{Keycert, RustlsConfig};

            let cert = include_bytes!("../certs/cert.pem").to_vec();
            let key = include_bytes!("../certs/key.pem").to_vec();
            let config =
                RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice()));
            let listener = TcpListener::new(("127.0.0.1", 2048)).rustls(config.clone());
            let acceptor = QuinnListener::new(config, ("127.0.0.1", 2048))
                .join(listener)
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        let _: &dyn Send = &async {
            let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6878));
            let acceptor = TcpListener::new(addr).bind().await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(unix)]
        let _: &dyn Send = &async {
            use crate::conn::UnixListener;

            let sock_file = "/tmp/test-salvo.sock";
            let acceptor = UnixListener::new(sock_file).bind().await;
            Server::new(acceptor).serve(Router::new()).await;
        };
    }
}