spacegate_ext_axum/
lib.rs

1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5pub use axum;
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::{BoxError, Router};
9use tokio::sync::RwLock;
10use tokio::task::JoinHandle;
11use tokio_util::sync::CancellationToken;
/// Default port for the global server.
const GLOBAL_SERVER_PORT: u16 = 9876;
/// Default host for the global server: the IPv4 unspecified address (`0.0.0.0`),
/// i.e. all IPv4 interfaces.
const GLOBAL_SERVER_HOST: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
/// Default bind address: `0.0.0.0:9876` (all IPv4 interfaces, port 9876).
const GLOBAL_SERVER_BIND: SocketAddr = SocketAddr::new(GLOBAL_SERVER_HOST, GLOBAL_SERVER_PORT);
18#[derive(Debug)]
19struct AxumServerInner {
20    pub bind: SocketAddr,
21    pub router: Router,
22    pub cancel_token: CancellationToken,
23    handle: Option<JoinHandle<Result<(), std::io::Error>>>,
24}
25
26impl Default for AxumServerInner {
27    fn default() -> Self {
28        Self {
29            bind: GLOBAL_SERVER_BIND,
30            router: Default::default(),
31            cancel_token: Default::default(),
32            handle: Default::default(),
33        }
34    }
35}
36
37/// Global axum http server for spacegate and its plugins.
38///
39/// # Usage
40/// ```
41/// # use spacegate_ext_axum::GlobalAxumServer;
42/// let server = GlobalAxumServer::default();
43/// ```
44#[derive(Debug, Clone)]
45pub struct GlobalAxumServer(Arc<RwLock<AxumServerInner>>);
46
47impl Default for GlobalAxumServer {
48    fn default() -> Self {
49        Self(AxumServerInner::global())
50    }
51}
52
53impl GlobalAxumServer {
54    /// Set the bind address for the server. If the server is already running, new bind address will take effect after restart.
55    pub async fn set_bind<A>(&self, socket_addr: A)
56    where
57        A: Into<SocketAddr>,
58    {
59        let socket_addr = socket_addr.into();
60        let mut wg = self.0.write().await;
61        wg.bind = socket_addr;
62    }
63
64    /// Get the bind address of the server.
65    pub async fn get_bind(&self) -> SocketAddr {
66        let wg = self.0.read().await;
67        wg.bind
68    }
69
70    /// Set the cancellation token for the server.
71    pub async fn set_cancellation(&self, token: CancellationToken) {
72        let mut wg = self.0.write().await;
73        wg.cancel_token = token;
74    }
75
76    /// Modify the router with the given closure.
77    pub async fn modify_router<M>(&self, modify: M)
78    where
79        M: FnOnce(Router) -> Router,
80    {
81        let mut wg = self.0.write().await;
82        let mut swap_out = Router::default();
83        std::mem::swap(&mut swap_out, &mut wg.router);
84        wg.router = (modify)(swap_out)
85    }
86
87    /// Start the server, if the server is already running, it will be restarted.
88    pub async fn start(&self) -> Result<(), std::io::Error> {
89        let mut wg = self.0.write().await;
90        wg.start().await
91    }
92
93    /// Shutdown the server.
94    pub async fn shutdown(&self) -> Result<(), std::io::Error> {
95        let mut wg = self.0.write().await;
96        wg.shutdown().await
97    }
98}
99
100impl AxumServerInner {
101    pub fn global() -> Arc<RwLock<AxumServerInner>> {
102        static GLOBAL: OnceLock<Arc<RwLock<AxumServerInner>>> = OnceLock::new();
103        GLOBAL.get_or_init(Default::default).clone()
104    }
105    pub async fn start(&mut self) -> Result<(), std::io::Error> {
106        let _shutdown_result = self.shutdown().await;
107        let tcp_listener = tokio::net::TcpListener::bind(self.bind).await?;
108        let cancel = self.cancel_token.clone();
109        let router = self.router.clone();
110        let task = tokio::spawn(async move { axum::serve(tcp_listener, router).with_graceful_shutdown(cancel.cancelled_owned()).await });
111        self.handle = Some(task);
112        Ok(())
113    }
114    pub async fn shutdown(&mut self) -> Result<(), std::io::Error> {
115        if let Some(handle) = self.handle.take() {
116            self.cancel_token.cancel();
117            handle.await.expect("tokio task join error")
118        } else {
119            Ok(())
120        }
121    }
122}
123
124pub struct InternalError {
125    reason: BoxError,
126}
127
128impl IntoResponse for InternalError {
129    fn into_response(self) -> Response {
130        let body = axum::body::Body::from(format!("Internal error: {}", self.reason));
131        Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(body).unwrap()
132    }
133}
134
135impl<E> From<E> for InternalError
136where
137    E: std::error::Error + Send + Sync + 'static,
138{
139    fn from(e: E) -> Self {
140        Self { reason: Box::new(e) }
141    }
142}