1use super::{banner, config};
7use crate::error::XOneError;
8use crate::xserver::Server;
9use crate::xutil;
10use axum::Router;
11use axum::serve::ListenerExt;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio::sync::watch;
15
16pub struct XAxumServer {
32 router: Router,
33 addr: Option<SocketAddr>,
35 enable_banner: Option<bool>,
37 use_http2: Option<bool>,
39 shutdown_tx: watch::Sender<bool>,
40}
41
42impl XAxumServer {
43 pub(crate) fn new(
45 router: Router,
46 addr: Option<SocketAddr>,
47 enable_banner: Option<bool>,
48 use_http2: Option<bool>,
49 ) -> Self {
50 let (shutdown_tx, _) = watch::channel(false);
51 Self {
52 router,
53 addr,
54 enable_banner,
55 use_http2,
56 shutdown_tx,
57 }
58 }
59
60 pub fn addr(&self) -> SocketAddr {
64 self.addr.unwrap_or_else(|| resolve_config().0)
65 }
66
67 pub fn use_http2(&self) -> bool {
71 self.use_http2.unwrap_or_else(|| resolve_config().2)
72 }
73
74 pub fn into_router(self) -> Router {
76 self.router
77 }
78
79 async fn run_http1(
81 &self,
82 listener: tokio::net::TcpListener,
83 mut shutdown_rx: watch::Receiver<bool>,
84 ) -> Result<(), XOneError> {
85 let listener = listener.tap_io(|tcp_stream| {
86 let _ = tcp_stream.set_nodelay(true);
87 });
88 axum::serve(listener, self.router.clone())
89 .with_graceful_shutdown(async move {
90 let _ = shutdown_rx.changed().await;
91 })
92 .await
93 .map_err(|e| XOneError::Server(format!("server error: {e}")))?;
94 Ok(())
95 }
96
97 async fn run_h2c(
99 &self,
100 listener: tokio::net::TcpListener,
101 mut shutdown_rx: watch::Receiver<bool>,
102 ) -> Result<(), XOneError> {
103 use hyper::body::Incoming;
104 use hyper_util::rt::{TokioExecutor, TokioIo};
105 use hyper_util::server::conn::auto;
106 use hyper_util::server::graceful::GracefulShutdown;
107 use tower_service::Service;
108
109 let graceful = GracefulShutdown::new();
110 let builder = auto::Builder::new(TokioExecutor::new());
112
113 loop {
114 tokio::select! {
115 result = listener.accept() => {
116 let (socket, _remote_addr) = result
117 .map_err(|e| XOneError::Server(format!("accept failed: {e}")))?;
118 let _ = socket.set_nodelay(true);
119
120 let tower_service = self.router.clone();
121
122 let hyper_service = hyper::service::service_fn(
123 move |request: axum::extract::Request<Incoming>| {
124 tower_service.clone().call(request)
125 },
126 );
127
128 let conn = builder
129 .serve_connection_with_upgrades(
130 TokioIo::new(socket),
131 hyper_service,
132 );
133
134 let conn = graceful.watch(conn.into_owned());
135 tokio::spawn(async move {
136 if let Err(e) = conn.await {
137 xutil::warn_if_enable_debug(
138 &format!("h2c connection error: {e}"),
139 );
140 }
141 });
142 }
143 _ = shutdown_rx.changed() => break,
144 }
145 }
146
147 tokio::time::timeout(Duration::from_secs(10), graceful.shutdown())
149 .await
150 .ok();
151
152 Ok(())
153 }
154}
155
156impl Server for XAxumServer {
157 async fn run(&self) -> Result<(), XOneError> {
158 let (addr, enable_banner, use_http2) = resolve_config();
160 let addr = self.addr.unwrap_or(addr);
161 let enable_banner = self.enable_banner.unwrap_or(enable_banner);
162 let use_http2 = self.use_http2.unwrap_or(use_http2);
163
164 if enable_banner {
165 banner::print_banner();
166 }
167
168 let listener = tokio::net::TcpListener::bind(addr)
169 .await
170 .map_err(|e| XOneError::Server(format!("bind {addr} failed: {e}")))?;
171
172 xutil::info_if_enable_debug(&format!("axum server listening on {addr}"));
173
174 let shutdown_rx = self.shutdown_tx.subscribe();
175
176 if use_http2 {
177 self.run_h2c(listener, shutdown_rx).await
178 } else {
179 self.run_http1(listener, shutdown_rx).await
180 }
181 }
182
183 async fn stop(&self) -> Result<(), XOneError> {
184 let _ = self.shutdown_tx.send(true);
185 Ok(())
186 }
187}
188
189fn resolve_config() -> (SocketAddr, bool, bool) {
191 let c = config::load_config();
192 let addr = parse_addr(&c.host, c.port);
193 (addr, c.enable_banner, c.use_http2)
194}
195
196fn parse_addr(host: &str, port: u16) -> SocketAddr {
198 format!("{host}:{port}").parse().unwrap_or_else(|e| {
199 xutil::warn_if_enable_debug(&format!(
200 "parse addr [{host}:{port}] failed, fallback to 0.0.0.0:{port}, err=[{e}]"
201 ));
202 SocketAddr::from(([0, 0, 0, 0], port))
203 })
204}