1#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6pub mod config;
8pub mod error;
10pub mod extractor;
12pub mod handler;
14pub mod middleware;
15#[cfg(feature = "openapi")]
16pub mod openapi;
17
18pub use axum;
19pub use spring::async_trait;
20pub use spring_macros::middlewares;
23pub use spring_macros::nest;
24
25pub use spring_macros::delete;
27pub use spring_macros::get;
28pub use spring_macros::head;
29pub use spring_macros::options;
30pub use spring_macros::patch;
31pub use spring_macros::post;
32pub use spring_macros::put;
33pub use spring_macros::route;
34pub use spring_macros::routes;
35pub use spring_macros::trace;
36
37#[cfg(feature = "openapi")]
38pub use spring_macros::api_route;
39#[cfg(feature = "openapi")]
40pub use spring_macros::api_routes;
41#[cfg(feature = "openapi")]
42pub use spring_macros::delete_api;
43#[cfg(feature = "openapi")]
44pub use spring_macros::get_api;
45#[cfg(feature = "openapi")]
46pub use spring_macros::head_api;
47#[cfg(feature = "openapi")]
48pub use spring_macros::options_api;
49#[cfg(feature = "openapi")]
50pub use spring_macros::patch_api;
51#[cfg(feature = "openapi")]
52pub use spring_macros::post_api;
53#[cfg(feature = "openapi")]
54pub use spring_macros::put_api;
55#[cfg(feature = "openapi")]
56pub use spring_macros::trace_api;
57
58pub use axum::routing::MethodFilter;
60
/// Router type used by this crate when the `openapi` feature is disabled:
/// a plain [`axum::Router`].
#[cfg(not(feature = "openapi"))]
pub type Router = axum::Router;
64pub use axum::routing::MethodRouter;
66
67#[cfg(feature = "openapi")]
68pub use aide;
69#[cfg(feature = "openapi")]
70pub use aide::openapi::OpenApi;
/// Router type used by this crate when the `openapi` feature is enabled:
/// aide's [`aide::axum::ApiRouter`].
#[cfg(feature = "openapi")]
pub type Router = aide::axum::ApiRouter;
73#[cfg(feature = "openapi")]
74pub use aide::axum::routing::ApiMethodRouter;
75
76#[cfg(feature = "openapi")]
77use aide::transform::TransformOpenApi;
78
79use anyhow::Context;
80use axum::Extension;
81use config::ServerConfig;
82use config::WebConfig;
83use spring::plugin::component::ComponentRef;
84use spring::plugin::ComponentRegistry;
85use spring::plugin::MutableComponentRegistry;
86use spring::{
87 app::{App, AppBuilder},
88 config::ConfigRegistry,
89 error::Result,
90 plugin::Plugin,
91};
92use std::{net::SocketAddr, ops::Deref, sync::Arc};
93
94#[cfg(feature = "openapi")]
95use crate::config::OpenApiConfig;
96
/// List of routers stored as a single component; with the `openapi` feature
/// the elements are [`aide::axum::ApiRouter`]s.
#[cfg(feature = "openapi")]
pub type Routers = Vec<aide::axum::ApiRouter>;
/// List of routers stored as a single component; without the `openapi`
/// feature the elements are plain [`axum::Router`]s.
#[cfg(not(feature = "openapi"))]
pub type Routers = Vec<axum::Router>;
102
/// Function-pointer type for a hook that receives a [`TransformOpenApi`] and
/// returns it, used to customise the generated OpenAPI document.
#[cfg(feature = "openapi")]
type OpenApiTransformer = fn(TransformOpenApi) -> TransformOpenApi;
106
/// Builder-style extension trait for registering web routing and, with the
/// `openapi` feature, OpenAPI metadata.
///
/// Every method returns `&mut Self` so that calls can be chained.
pub trait WebConfigurator {
    /// Registers `router` so it becomes part of the served application.
    fn add_router(&mut self, router: Router) -> &mut Self;

    /// Registers the [`OpenApi`] document to use as the base of the
    /// generated specification.
    #[cfg(feature = "openapi")]
    fn openapi(&mut self, openapi: OpenApi) -> &mut Self;

    /// Registers a transformer function that is applied to the OpenAPI
    /// document when the API router is finished.
    #[cfg(feature = "openapi")]
    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self;
}
120
121impl WebConfigurator for AppBuilder {
122 fn add_router(&mut self, router: Router) -> &mut Self {
123 if let Some(routers) = self.get_component_ref::<Routers>() {
124 unsafe {
125 let raw_ptr = ComponentRef::into_raw(routers);
126 let routers = &mut *(raw_ptr as *mut Routers);
127 routers.push(router);
128 }
129 self
130 } else {
131 self.add_component(vec![router])
132 }
133 }
134
135 #[cfg(feature = "openapi")]
137 fn openapi(&mut self, openapi: OpenApi) -> &mut Self {
138 self.add_component(openapi)
139 }
140
141 #[cfg(feature = "openapi")]
142 fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self {
143 self.add_component(api_docs)
144 }
145}
146
/// State attached to the router as an [`Extension`] layer before the server
/// starts.
#[derive(Clone)]
pub struct AppState {
    /// Shared handle to the running application.
    pub app: Arc<App>,
}
153
/// Plugin that merges the registered routers and schedules the axum HTTP
/// server.
pub struct WebPlugin;
156
157#[async_trait]
158impl Plugin for WebPlugin {
159 async fn build(&self, app: &mut AppBuilder) {
160 let config = app
161 .get_config::<WebConfig>()
162 .expect("web plugin config load failed");
163
164 let routers = app.get_component_ref::<Routers>();
166 let mut router: Router = match routers {
167 Some(rs) => {
168 let mut router = Router::new();
169 for r in rs.deref().iter() {
170 router = router.merge(r.to_owned());
171 }
172 router
173 }
174 None => Router::new(),
175 };
176 if let Some(middlewares) = config.middlewares {
177 router = crate::middleware::apply_middleware(router, middlewares);
178 }
179
180 app.add_component(router);
181
182 let server_conf = config.server;
183 #[cfg(feature = "openapi")]
184 {
185 let openapi_conf = config.openapi;
186 app.add_component(openapi_conf.clone());
187 }
188
189 app.add_scheduler(move |app: Arc<App>| {
190 Box::new(Self::schedule(app, server_conf))
191 });
192 }
193}
194
195impl WebPlugin {
196 async fn schedule(
197 app: Arc<App>,
198 config: ServerConfig,
199 ) -> Result<String> {
200 let router = app.get_expect_component::<Router>();
201
202 let addr = SocketAddr::from((config.binding, config.port));
204 let listener = tokio::net::TcpListener::bind(addr)
205 .await
206 .with_context(|| format!("bind tcp listener failed:{addr}"))?;
207 tracing::info!("bind tcp listener: {addr}");
208
209 #[cfg(feature = "openapi")]
211 let router = {
212 let openapi_conf = app.get_expect_component::<OpenApiConfig>();
213 finish_openapi(&app, router, openapi_conf)
214 };
215
216 let router = router.layer(Extension(AppState { app }));
218
219 tracing::info!("axum server started");
220 if config.connect_info {
221 let service = router.into_make_service_with_connect_info::<SocketAddr>();
223 let server = axum::serve(listener, service);
224 if config.graceful {
225 server.with_graceful_shutdown(shutdown_signal()).await
226 } else {
227 server.await
228 }
229 } else {
230 let service = router.into_make_service();
231 let server = axum::serve(listener, service);
232 if config.graceful {
233 server.with_graceful_shutdown(shutdown_signal()).await
234 } else {
235 server.await
236 }
237 }
238 .context("start axum server failed")?;
239
240 Ok("axum schedule finished".to_string())
241 }
242}
243
/// Configures aide's OpenAPI generation for use with this crate.
///
/// Installs an error callback that logs each generation error through
/// `tracing::error!`, then calls `aide::generate::extract_schemas(false)`.
#[cfg(feature = "openapi")]
pub fn enable_openapi() {
    aide::generate::on_error(|error| {
        // Report generation problems in the log.
        tracing::error!("{error}");
    });
    // NOTE(review): `false` presumably keeps schemas inline instead of
    // collecting them into shared components — confirm against aide's docs.
    aide::generate::extract_schemas(false);
}
251
/// Turns the API router into a plain axum router that also serves the
/// generated documentation.
///
/// The documentation routes are nested under `openapi_conf.doc_prefix`. The
/// base document is the registered [`OpenApi`] component when there is one,
/// otherwise a default document carrying `openapi_conf.info`. A registered
/// [`OpenApiTransformer`] component, if any, is applied while finishing. The
/// finished document is shared with handlers through an [`Extension`].
#[cfg(feature = "openapi")]
fn finish_openapi(
    app: &App,
    router: aide::axum::ApiRouter,
    openapi_conf: OpenApiConfig,
) -> axum::Router {
    let with_docs = router.nest_api_service(&openapi_conf.doc_prefix, docs_routes(&openapi_conf));

    let mut spec = match app.get_component::<OpenApi>() {
        Some(registered) => registered,
        None => OpenApi {
            info: openapi_conf.info,
            ..Default::default()
        },
    };

    let finished = match app.get_component::<OpenApiTransformer>() {
        Some(transform) => with_docs.finish_api_with(&mut spec, transform),
        None => with_docs.finish_api(&mut spec),
    };

    finished.layer(Extension(Arc::new(spec)))
}
273
/// Builds the router that serves the API documentation.
///
/// `/openapi.json` is always routed to the handler that returns the document.
/// Depending on the enabled features, `/scalar`, `/redoc` and `/swagger` UI
/// routes are added as well; each is given `{doc_prefix}/openapi.json` as the
/// location of the document and the configured title.
#[cfg(feature = "openapi")]
pub fn docs_routes(OpenApiConfig { doc_prefix, info }: &OpenApiConfig) -> aide::axum::ApiRouter {
    // Underscore-prefixed because neither value is used when no UI feature
    // is enabled.
    let _spec_url = format!("{doc_prefix}/openapi.json");
    let _page_title = &info.title;

    let docs = aide::axum::ApiRouter::new();

    #[cfg(feature = "openapi-scalar")]
    let docs = docs.route(
        "/scalar",
        aide::scalar::Scalar::new(&_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );
    #[cfg(feature = "openapi-redoc")]
    let docs = docs.route(
        "/redoc",
        aide::redoc::Redoc::new(&_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );
    #[cfg(feature = "openapi-swagger")]
    let docs = docs.route(
        "/swagger",
        aide::swagger::Swagger::new(&_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );

    docs.route("/openapi.json", axum::routing::get(serve_docs))
}
304
/// Handler that returns the shared OpenAPI document as a JSON response.
#[cfg(feature = "openapi")]
async fn serve_docs(Extension(api): Extension<Arc<OpenApi>>) -> impl aide::axum::IntoApiResponse {
    use axum::response::IntoResponse;

    let document: &OpenApi = &api;
    axum::Json(document).into_response()
}
309
/// Identity path-item transform: returns `path_item` unchanged.
#[cfg(feature = "openapi")]
pub fn default_transform<'a>(
    path_item: aide::transform::TransformPathItem<'a>,
) -> aide::transform::TransformPathItem<'a> {
    path_item
}
316
317async fn shutdown_signal() {
318 let ctrl_c = async {
319 tokio::signal::ctrl_c()
320 .await
321 .expect("failed to install Ctrl+C handler");
322 };
323
324 #[cfg(unix)]
325 let terminate = async {
326 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
327 .expect("failed to install signal handler")
328 .recv()
329 .await;
330 };
331
332 #[cfg(not(unix))]
333 let terminate = std::future::pending::<()>();
334
335 tokio::select! {
336 _ = ctrl_c => {
337 tracing::info!("Received Ctrl+C signal, waiting for web server shutdown")
338 },
339 _ = terminate => {
340 tracing::info!("Received kill signal, waiting for web server shutdown")
341 },
342 }
343}