1use std::fmt::{self, Debug, Formatter};
3use std::io::Result as IoResult;
4use std::sync::Arc;
5#[cfg(feature = "server-handle")]
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8#[cfg(not(any(feature = "http1", feature = "http2", feature = "quinn")))]
9compile_error!(
10 "You have enabled `server` feature, it requires at least one of the following features: http1, http2, quinn."
11);
12
13#[cfg(feature = "http1")]
14use hyper::server::conn::http1;
15#[cfg(feature = "http2")]
16use hyper::server::conn::http2;
17#[cfg(feature = "server-handle")]
18use tokio::{
19 sync::{
20 Notify,
21 mpsc::{UnboundedReceiver, UnboundedSender},
22 },
23 time::Duration,
24};
25#[cfg(feature = "server-handle")]
26use tokio_util::sync::CancellationToken;
27
28use crate::Service;
29#[cfg(feature = "quinn")]
30use crate::conn::quinn;
31use crate::conn::{Accepted, Coupler, Acceptor, Holding, HttpBuilder};
32use crate::fuse::{ArcFuseFactory, FuseFactory};
33use crate::http::{HeaderValue, Version};
34
35cfg_feature! {
36 #![feature ="server-handle"]
37 #[derive(Clone)]
39 pub struct ServerHandle {
40 tx_cmd: UnboundedSender<ServerCommand>,
41 }
42 impl Debug for ServerHandle {
43 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
44 f.debug_struct("ServerHandle").finish()
45 }
46 }
47}
48
#[cfg(feature = "server-handle")]
impl ServerHandle {
    /// Stops the server immediately.
    ///
    /// Connections still in flight are cancelled rather than drained. If the
    /// server has already shut down, the request is silently discarded.
    pub fn stop_forcible(&self) {
        self.dispatch(ServerCommand::StopForcible);
    }

    /// Stops the server gracefully.
    ///
    /// The server stops accepting new connections and waits for the open ones to
    /// finish. When `timeout` is `Some`, connections still open once it elapses
    /// are cancelled as with [`ServerHandle::stop_forcible`]. `None` waits
    /// without a deadline.
    pub fn stop_graceful(&self, timeout: impl Into<Option<Duration>>) {
        let deadline: Option<Duration> = timeout.into();
        self.dispatch(ServerCommand::StopGraceful(deadline));
    }

    /// Sends `cmd` to the server loop.
    ///
    /// A send error only means the server has already dropped its receiver, so
    /// the error is deliberately ignored.
    fn dispatch(&self, cmd: ServerCommand) {
        let _ = self.tx_cmd.send(cmd);
    }
}
96
/// Commands sent by [`Server`] or [`ServerHandle`] to the running serve loop.
#[cfg(feature = "server-handle")]
enum ServerCommand {
    /// Cancel every open connection and stop at once.
    StopForcible,
    /// Stop accepting connections, then wait for the open ones to close.
    ///
    /// The optional duration bounds the wait; once it elapses, the remaining
    /// connections are cancelled as with `StopForcible`.
    StopGraceful(Option<Duration>),
}
102
103pub struct Server<A> {
107 acceptor: A,
108 builder: HttpBuilder,
109 fuse_factory: Option<ArcFuseFactory>,
110 #[cfg(feature = "server-handle")]
111 tx_cmd: UnboundedSender<ServerCommand>,
112 #[cfg(feature = "server-handle")]
113 rx_cmd: UnboundedReceiver<ServerCommand>,
114}
115
116impl<A> Debug for Server<A> {
117 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
118 f.debug_struct("Server").finish()
119 }
120}
121
122impl<A: Acceptor + Send> Server<A> {
123 pub fn new(acceptor: A) -> Self {
137 Self::with_http_builder(acceptor, HttpBuilder::new())
138 }
139
140 pub fn with_http_builder(acceptor: A, builder: HttpBuilder) -> Self {
142 #[cfg(feature = "server-handle")]
143 let (tx_cmd, rx_cmd) = tokio::sync::mpsc::unbounded_channel();
144 Self {
145 acceptor,
146 builder,
147 fuse_factory: None,
148 #[cfg(feature = "server-handle")]
149 tx_cmd,
150 #[cfg(feature = "server-handle")]
151 rx_cmd,
152 }
153 }
154
155 #[must_use]
157 pub fn fuse_factory<F>(mut self, factory: F) -> Self
158 where
159 F: FuseFactory + Send + Sync + 'static,
160 {
161 self.fuse_factory = Some(Arc::new(factory));
162 self
163 }
164
165 cfg_feature! {
166 #![feature = "server-handle"]
167 pub fn handle(&self) -> ServerHandle {
169 ServerHandle {
170 tx_cmd: self.tx_cmd.clone(),
171 }
172 }
173
174 pub fn stop_forcible(&self) {
178 let _ = self.tx_cmd.send(ServerCommand::StopForcible);
179 }
180
181 pub fn stop_graceful(&self, timeout: impl Into<Option<Duration>>) {
187 let _ = self.tx_cmd.send(ServerCommand::StopGraceful(timeout.into()));
188 }
189 }
190
191 #[inline]
193 pub fn holdings(&self) -> &[Holding] {
194 self.acceptor.holdings()
195 }
196
197 cfg_feature! {
198 #![feature = "http1"]
199 pub fn http1_mut(&mut self) -> &mut http1::Builder {
201 &mut self.builder.http1
202 }
203 }
204
205 cfg_feature! {
206 #![feature = "http2"]
207 pub fn http2_mut(&mut self) -> &mut http2::Builder<crate::rt::tokio::TokioExecutor> {
209 &mut self.builder.http2
210 }
211 }
212
213 cfg_feature! {
214 #![feature = "quinn"]
215 pub fn quinn_mut(&mut self) -> &mut quinn::Builder {
217 &mut self.builder.quinn
218 }
219 }
220
221 #[inline]
240 pub async fn serve<S>(self, service: S)
241 where
242 S: Into<Service> + Send,
243 {
244 self.try_serve(service)
245 .await
246 .expect("failed to call `Server::serve`");
247 }
248
249 #[cfg(feature = "server-handle")]
251 #[allow(clippy::manual_async_fn)] pub fn try_serve<S>(self, service: S) -> impl Future<Output = IoResult<()>> + Send
253 where
254 S: Into<Service> + Send,
255 {
256 async {
257 let Self {
258 mut acceptor,
259 builder,
260 fuse_factory,
261 mut rx_cmd,
262 ..
263 } = self;
264 let alive_connections = Arc::new(AtomicUsize::new(0));
265 let notify = Arc::new(Notify::new());
266 let force_stop_token = CancellationToken::new();
267 let graceful_stop_token = CancellationToken::new();
268
269 let mut alt_svc_h3 = None;
270 for holding in acceptor.holdings() {
271 tracing::info!("listening {}", holding);
272 if holding.http_versions.contains(&Version::HTTP_3) {
273 if let Some(addr) = holding.local_addr.clone().into_std() {
274 let port = addr.port();
275 alt_svc_h3 = Some(
276 format!(r#"h3=":{port}"; ma=2592000,h3-29=":{port}"; ma=2592000"#)
277 .parse::<HeaderValue>()
278 .expect("Parse alt-svc header should not failed."),
279 );
280 }
281 }
282 }
283
284 let service: Arc<Service> = Arc::new(service.into());
285 let builder = Arc::new(builder);
286 loop {
287 tokio::select! {
288 accepted = acceptor.accept(fuse_factory.clone()) => {
289 match accepted {
290 Ok(Accepted { coupler, stream, fusewire, local_addr, remote_addr, http_scheme, ..}) => {
291 alive_connections.fetch_add(1, Ordering::Release);
292
293 let service = service.clone();
294 let alive_connections = alive_connections.clone();
295 let notify = notify.clone();
296 let handler = service.hyper_handler(local_addr, remote_addr, http_scheme, fusewire, alt_svc_h3.clone());
297 let builder = builder.clone();
298
299 let force_stop_token = force_stop_token.clone();
300 let graceful_stop_token = graceful_stop_token.clone();
301
302 tokio::spawn(async move {
303 let conn = coupler.couple(stream, handler, builder, Some(graceful_stop_token.clone()));
304 tokio::select! {
305 _ = conn => {
306 },
307 _ = force_stop_token.cancelled() => {
308 }
309 }
310
311 if alive_connections.fetch_sub(1, Ordering::Acquire) == 1 {
312 if graceful_stop_token.is_cancelled() {
315 notify.notify_one();
316 }
317 }
318 });
319 },
320 Err(e) => {
321 tracing::error!(error = ?e, "accept connection failed");
322 }
323 }
324 }
325 Some(cmd) = rx_cmd.recv() => {
326 match cmd {
327 ServerCommand::StopGraceful(timeout) => {
328 let graceful_stop_token = graceful_stop_token.clone();
329 graceful_stop_token.cancel();
330 if let Some(timeout) = timeout {
331 tracing::info!(
332 timeout_in_seconds = timeout.as_secs_f32(),
333 "initiate graceful stop server",
334 );
335
336 let force_stop_token = force_stop_token.clone();
337 tokio::spawn(async move {
338 tokio::time::sleep(timeout).await;
339 force_stop_token.cancel();
340 });
341 } else {
342 tracing::info!("initiate graceful stop server");
343 }
344 },
345 ServerCommand::StopForcible => {
346 tracing::info!("force stop server");
347 force_stop_token.cancel();
348 },
349 }
350 break;
351 },
352 }
353 }
354
355 if !force_stop_token.is_cancelled() && alive_connections.load(Ordering::Acquire) > 0 {
356 tracing::info!(
357 "wait for {} connections to close.",
358 alive_connections.load(Ordering::Acquire)
359 );
360 notify.notified().await;
361 }
362
363 tracing::info!("server stopped");
364 Ok(())
365 }
366 }
367 #[cfg(not(feature = "server-handle"))]
369 pub async fn try_serve<S>(self, service: S) -> IoResult<()>
370 where
371 S: Into<Service> + Send,
372 {
373 let Self {
374 mut acceptor,
375 builder,
376 fuse_factory,
377 ..
378 } = self;
379 let mut alt_svc_h3 = None;
380 for holding in acceptor.holdings() {
381 tracing::info!("listening {}", holding);
382 if holding.http_versions.contains(&Version::HTTP_3) {
383 if let Some(addr) = holding.local_addr.clone().into_std() {
384 let port = addr.port();
385 alt_svc_h3 = Some(
386 format!(r#"h3=":{port}"; ma=2592000,h3-29=":{port}"; ma=2592000"#)
387 .parse::<HeaderValue>()
388 .expect("Parse alt-svc header should not failed."),
389 );
390 }
391 }
392 }
393
394 let service: Arc<Service> = Arc::new(service.into());
395 let builder = Arc::new(builder);
396 loop {
397 match acceptor.accept(fuse_factory.clone()).await {
398 Ok(Accepted {
399 coupler,
400 stream,
401 fusewire,
402 local_addr,
403 remote_addr,
404 http_scheme,
405 ..
406 }) => {
407 let service = service.clone();
408 let handler = service.hyper_handler(
409 local_addr,
410 remote_addr,
411 http_scheme,
412 fusewire,
413 alt_svc_h3.clone(),
414 );
415 let builder = builder.clone();
416
417 tokio::spawn(async move {
418 let _ = coupler.couple(stream, handler, builder, None).await;
419 });
420 }
421 Err(e) => {
422 tracing::error!(error = ?e, "accept connection failed");
423 }
424 }
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use serde::Serialize;

    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    #[tokio::test]
    async fn test_server() {
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("Hello World")
        }
        #[handler]
        async fn json(res: &mut Response) {
            #[derive(Serialize, Debug)]
            struct User {
                name: String,
            }
            res.render(Json(User {
                name: "jobs".into(),
            }));
        }
        let router = Router::new()
            .get(hello)
            .push(Router::with_path("json").get(json));
        let service = Service::new(router);

        // Requests are dispatched in-process by `TestClient`; no socket is bound.
        let base_url = "http://127.0.0.1:8698";
        let result = TestClient::get(base_url)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(result, "Hello World");

        let result = TestClient::get(format!("{base_url}/json"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(result, r#"{"name":"jobs"}"#);

        let result = TestClient::get(format!("{base_url}/not_exist"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("Not Found"));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "application/json", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains(r#""code":404"#));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "text/plain", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("code: 404"));
        let result = TestClient::get(format!("{base_url}/not_exist"))
            .add_header("accept", "application/xml", true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(result.contains("<code>404</code>"));
    }

    #[cfg(feature = "server-handle")]
    #[tokio::test]
    async fn test_server_handle_stop() {
        use std::time::Duration;
        use tokio::time::timeout;

        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // other processes or with tests running in parallel.
        let acceptor = crate::conn::TcpListener::new("127.0.0.1:0").bind().await;
        let server = Server::new(acceptor);
        let handle = server.handle();
        let server_task = tokio::spawn(server.try_serve(Router::new()));

        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.stop_forcible();

        let joined = timeout(Duration::from_secs(1), server_task)
            .await
            .expect("Server should stop forcibly within 1 second.");
        let served = joined.expect("Server task should not panic.");
        assert!(served.is_ok(), "try_serve should return Ok.");

        let acceptor = crate::conn::TcpListener::new("127.0.0.1:0").bind().await;
        let server = Server::new(acceptor);
        let handle = server.handle();
        let server_task = tokio::spawn(server.try_serve(Router::new()));

        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.stop_graceful(None);

        let joined = timeout(Duration::from_secs(1), server_task)
            .await
            .expect("Server should stop gracefully within 1 second.");
        let served = joined.expect("Server task should not panic.");
        assert!(served.is_ok(), "try_serve should return Ok.");
    }

    // Compile-time regression check: every `serve` future below must be `Send`.
    // The futures are never polled.
    #[test]
    fn test_regression_209() {
        #[cfg(feature = "acme")]
        let _: &dyn Send = &async {
            let acceptor = TcpListener::new("127.0.0.1:0")
                .acme()
                .add_domain("test.salvo.rs")
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "native-tls")]
        let _: &dyn Send = &async {
            use crate::conn::native_tls::NativeTlsConfig;

            let identity = if cfg!(target_os = "macos") {
                include_bytes!("../certs/identity-legacy.p12").to_vec()
            } else {
                include_bytes!("../certs/identity.p12").to_vec()
            };
            let acceptor = TcpListener::new("127.0.0.1:0")
                .native_tls(NativeTlsConfig::new().pkcs12(identity).password("mypass"))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "openssl")]
        let _: &dyn Send = &async {
            use crate::conn::openssl::{Keycert, OpensslConfig};

            let acceptor = TcpListener::new("127.0.0.1:0")
                .openssl(OpensslConfig::new(
                    Keycert::new()
                        .key_from_path("certs/key.pem")
                        .unwrap()
                        .cert_from_path("certs/cert.pem")
                        .unwrap(),
                ))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "rustls")]
        let _: &dyn Send = &async {
            use crate::conn::rustls::{Keycert, RustlsConfig};

            let acceptor = TcpListener::new("127.0.0.1:0")
                .rustls(RustlsConfig::new(
                    Keycert::new()
                        .key_from_path("certs/key.pem")
                        .unwrap()
                        .cert_from_path("certs/cert.pem")
                        .unwrap(),
                ))
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(feature = "quinn")]
        let _: &dyn Send = &async {
            use crate::conn::rustls::{Keycert, RustlsConfig};

            let cert = include_bytes!("../certs/cert.pem").to_vec();
            let key = include_bytes!("../certs/key.pem").to_vec();
            let config =
                RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice()));
            let listener = TcpListener::new(("127.0.0.1", 2048)).rustls(config.clone());
            let acceptor = QuinnListener::new(config, ("127.0.0.1", 2048))
                .join(listener)
                .bind()
                .await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        let _: &dyn Send = &async {
            let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6878));
            let acceptor = TcpListener::new(addr).bind().await;
            Server::new(acceptor).serve(Router::new()).await;
        };
        #[cfg(unix)]
        let _: &dyn Send = &async {
            use crate::conn::UnixListener;

            let sock_file = "/tmp/test-salvo.sock";
            let acceptor = UnixListener::new(sock_file).bind().await;
            Server::new(acceptor).serve(Router::new()).await;
        };
    }
}