1use super::{banner, config};
7use crate::error::XOneError;
8use crate::xserver::Server;
9use crate::xutil;
10use axum::Router;
11use axum::serve::ListenerExt;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::sync::watch;
15
16pub struct XAxumServer {
32 router: Router,
33 addr: Option<SocketAddr>,
35 enable_banner: Option<bool>,
37 use_http2: Option<bool>,
39 shutdown_tx: watch::Sender<bool>,
40}
41
42impl XAxumServer {
43 pub(crate) fn new(
45 router: Router,
46 addr: Option<SocketAddr>,
47 enable_banner: Option<bool>,
48 use_http2: Option<bool>,
49 ) -> Self {
50 let (shutdown_tx, _) = watch::channel(false);
51 Self {
52 router,
53 addr,
54 enable_banner,
55 use_http2,
56 shutdown_tx,
57 }
58 }
59
60 pub fn addr(&self) -> SocketAddr {
64 self.addr.unwrap_or_else(|| resolve_config().0)
65 }
66
67 pub fn use_http2(&self) -> bool {
71 self.use_http2.unwrap_or_else(|| resolve_config().2)
72 }
73
74 pub fn into_router(self) -> Router {
76 self.router
77 }
78
79 async fn run_http1(
81 &self,
82 listener: tokio::net::TcpListener,
83 mut shutdown_rx: watch::Receiver<bool>,
84 ) -> Result<(), XOneError> {
85 let listener = listener.tap_io(|tcp_stream| {
86 let _ = tcp_stream.set_nodelay(true);
87 });
88 axum::serve(listener, self.router.clone())
89 .with_graceful_shutdown(async move {
90 let _ = shutdown_rx.changed().await;
91 })
92 .await
93 .map_err(|e| XOneError::Server(format!("server error: {e}")))?;
94 Ok(())
95 }
96
97 async fn run_h2c(
99 &self,
100 listener: tokio::net::TcpListener,
101 mut shutdown_rx: watch::Receiver<bool>,
102 ) -> Result<(), XOneError> {
103 use hyper::body::Incoming;
104 use hyper_util::rt::{TokioExecutor, TokioIo};
105 use hyper_util::server::conn::auto;
106 use hyper_util::server::graceful::GracefulShutdown;
107 use tower_service::Service;
108
109 let graceful = GracefulShutdown::new();
110
111 loop {
112 tokio::select! {
113 result = listener.accept() => {
114 let (socket, _remote_addr) = result
115 .map_err(|e| XOneError::Server(format!("accept failed: {e}")))?;
116 let _ = socket.set_nodelay(true);
117
118 let tower_service = self.router.clone();
119
120 let hyper_service = hyper::service::service_fn(
121 move |request: axum::extract::Request<Incoming>| {
122 tower_service.clone().call(request)
123 },
124 );
125
126 let builder = auto::Builder::new(TokioExecutor::new());
127 let conn = builder
128 .serve_connection_with_upgrades(
129 TokioIo::new(socket),
130 hyper_service,
131 );
132
133 let conn = graceful.watch(conn.into_owned());
134 tokio::spawn(async move {
135 if let Err(e) = conn.await {
136 xutil::warn_if_enable_debug(
137 &format!("h2c connection error: {e}"),
138 );
139 }
140 });
141 }
142 _ = shutdown_rx.changed() => break,
143 }
144 }
145
146 tokio::time::timeout(Duration::from_secs(10), graceful.shutdown())
148 .await
149 .ok();
150
151 Ok(())
152 }
153}
154
155impl Server for XAxumServer {
156 async fn run(&self) -> Result<(), XOneError> {
157 let (addr, enable_banner, use_http2) = resolve_config();
159 let addr = self.addr.unwrap_or(addr);
160 let enable_banner = self.enable_banner.unwrap_or(enable_banner);
161 let use_http2 = self.use_http2.unwrap_or(use_http2);
162
163 if enable_banner {
164 banner::print_banner();
165 }
166
167 let listener = tokio::net::TcpListener::bind(addr)
168 .await
169 .map_err(|e| XOneError::Server(format!("bind {addr} failed: {e}")))?;
170
171 xutil::info_if_enable_debug(&format!("axum server listening on {addr}"));
172
173 let shutdown_rx = self.shutdown_tx.subscribe();
174
175 if use_http2 {
176 self.run_h2c(listener, shutdown_rx).await
177 } else {
178 self.run_http1(listener, shutdown_rx).await
179 }
180 }
181
182 async fn stop(&self) -> Result<(), XOneError> {
183 let _ = self.shutdown_tx.send(true);
184 Ok(())
185 }
186}
187
188fn resolve_config() -> (SocketAddr, bool, bool) {
190 let c = config::load_config();
191 let addr = parse_addr(&c.host, c.port);
192 (addr, c.enable_banner, c.use_http2)
193}
194
195fn parse_addr(host: &str, port: u16) -> SocketAddr {
197 format!("{host}:{port}").parse().unwrap_or_else(|e| {
198 xutil::warn_if_enable_debug(&format!(
199 "parse addr [{host}:{port}] failed, fallback to 0.0.0.0:{port}, err=[{e}]"
200 ));
201 SocketAddr::from(([0, 0, 0, 0], port))
202 })
203}