spring_web/
lib.rs

1//! [![spring-rs](https://img.shields.io/github/stars/spring-rs/spring-rs)](https://spring-rs.github.io/docs/plugins/spring-web)
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6/// spring-web config
7pub mod config;
8/// spring-web defined error
9pub mod error;
10/// axum extract
11pub mod extractor;
12/// axum route handler
13pub mod handler;
14pub mod middleware;
15#[cfg(feature = "openapi")]
16pub mod openapi;
17
18pub use axum;
19pub use spring::async_trait;
20/////////////////web-macros/////////////////////
21/// To use these Procedural Macros, you need to add `spring-web` dependency
22pub use spring_macros::middlewares;
23pub use spring_macros::nest;
24
25// route macros
26pub use spring_macros::delete;
27pub use spring_macros::get;
28pub use spring_macros::head;
29pub use spring_macros::options;
30pub use spring_macros::patch;
31pub use spring_macros::post;
32pub use spring_macros::put;
33pub use spring_macros::route;
34pub use spring_macros::routes;
35pub use spring_macros::trace;
36
37#[cfg(feature = "openapi")]
38pub use spring_macros::api_route;
39#[cfg(feature = "openapi")]
40pub use spring_macros::api_routes;
41#[cfg(feature = "openapi")]
42pub use spring_macros::delete_api;
43#[cfg(feature = "openapi")]
44pub use spring_macros::get_api;
45#[cfg(feature = "openapi")]
46pub use spring_macros::head_api;
47#[cfg(feature = "openapi")]
48pub use spring_macros::options_api;
49#[cfg(feature = "openapi")]
50pub use spring_macros::patch_api;
51#[cfg(feature = "openapi")]
52pub use spring_macros::post_api;
53#[cfg(feature = "openapi")]
54pub use spring_macros::put_api;
55#[cfg(feature = "openapi")]
56pub use spring_macros::trace_api;
57
58/// axum::routing::MethodFilter re-export
59pub use axum::routing::MethodFilter;
60
61/// Router with AppState
62#[cfg(not(feature = "openapi"))]
63pub type Router = axum::Router;
64/// MethodRouter with AppState
65pub use axum::routing::MethodRouter;
66
67#[cfg(feature = "openapi")]
68pub use aide;
69#[cfg(feature = "openapi")]
70pub use aide::openapi::OpenApi;
71#[cfg(feature = "openapi")]
72pub type Router = aide::axum::ApiRouter;
73#[cfg(feature = "openapi")]
74pub use aide::axum::routing::ApiMethodRouter;
75
76#[cfg(feature = "openapi")]
77use aide::transform::TransformOpenApi;
78
79use anyhow::Context;
80use axum::Extension;
81use config::ServerConfig;
82use config::WebConfig;
83use spring::plugin::component::ComponentRef;
84use spring::plugin::ComponentRegistry;
85use spring::plugin::MutableComponentRegistry;
86use spring::{
87    app::{App, AppBuilder},
88    config::ConfigRegistry,
89    error::Result,
90    plugin::Plugin,
91};
92use std::{net::SocketAddr, ops::Deref, sync::Arc};
93
94#[cfg(feature = "openapi")]
95use crate::config::OpenApiConfig;
96
97/// Routers collection
98#[cfg(feature = "openapi")]
99pub type Routers = Vec<aide::axum::ApiRouter>;
100#[cfg(not(feature = "openapi"))]
101pub type Routers = Vec<axum::Router>;
102
103/// OpenAPI
104#[cfg(feature = "openapi")]
105type OpenApiTransformer = fn(TransformOpenApi) -> TransformOpenApi;
106
107/// Web Configurator
108pub trait WebConfigurator {
109    /// add route to app registry
110    fn add_router(&mut self, router: Router) -> &mut Self;
111
112    /// Initialize OpenAPI Documents
113    #[cfg(feature = "openapi")]
114    fn openapi(&mut self, openapi: OpenApi) -> &mut Self;
115
116    /// Defining OpenAPI Documents
117    #[cfg(feature = "openapi")]
118    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self;
119}
120
121impl WebConfigurator for AppBuilder {
122    fn add_router(&mut self, router: Router) -> &mut Self {
123        if let Some(routers) = self.get_component_ref::<Routers>() {
124            unsafe {
125                let raw_ptr = ComponentRef::into_raw(routers);
126                let routers = &mut *(raw_ptr as *mut Routers);
127                routers.push(router);
128            }
129            self
130        } else {
131            self.add_component(vec![router])
132        }
133    }
134
135    /// Initialize OpenAPI Documents
136    #[cfg(feature = "openapi")]
137    fn openapi(&mut self, openapi: OpenApi) -> &mut Self {
138        self.add_component(openapi)
139    }
140
141    #[cfg(feature = "openapi")]
142    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self {
143        self.add_component(api_docs)
144    }
145}
146
147/// State of App
148#[derive(Clone)]
149pub struct AppState {
150    /// App Registry Ref
151    pub app: Arc<App>,
152}
153
154/// Web Plugin Definition
155pub struct WebPlugin;
156
157#[async_trait]
158impl Plugin for WebPlugin {
159    async fn build(&self, app: &mut AppBuilder) {
160        let config = app
161            .get_config::<WebConfig>()
162            .expect("web plugin config load failed");
163
164        // 1. collect router
165        let routers = app.get_component_ref::<Routers>();
166        let mut router: Router = match routers {
167            Some(rs) => {
168                let mut router = Router::new();
169                for r in rs.deref().iter() {
170                    router = router.merge(r.to_owned());
171                }
172                router
173            }
174            None => Router::new(),
175        };
176        if let Some(middlewares) = config.middlewares {
177            router = crate::middleware::apply_middleware(router, middlewares);
178        }
179
180        app.add_component(router);
181
182        let server_conf = config.server;
183        #[cfg(feature = "openapi")]
184        {
185            let openapi_conf = config.openapi;
186            app.add_component(openapi_conf.clone());
187        }
188
189        app.add_scheduler(move |app: Arc<App>| {
190            Box::new(Self::schedule(app, server_conf))
191        });
192    }
193}
194
195impl WebPlugin {
196    async fn schedule(
197        app: Arc<App>,
198        config: ServerConfig,
199    ) -> Result<String> {
200        let router = app.get_expect_component::<Router>();
201
202        // 2. bind tcp listener
203        let addr = SocketAddr::from((config.binding, config.port));
204        let listener = tokio::net::TcpListener::bind(addr)
205            .await
206            .with_context(|| format!("bind tcp listener failed:{addr}"))?;
207        tracing::info!("bind tcp listener: {addr}");
208
209        // 3. openapi
210        #[cfg(feature = "openapi")]
211        let router = {
212            let openapi_conf = app.get_expect_component::<OpenApiConfig>();
213            finish_openapi(&app, router, openapi_conf)
214        };
215
216        // 4. axum server
217        let router = router.layer(Extension(AppState { app }));
218
219        tracing::info!("axum server started");
220        if config.connect_info {
221            // with client connect info
222            let service = router.into_make_service_with_connect_info::<SocketAddr>();
223            let server = axum::serve(listener, service);
224            if config.graceful {
225                server.with_graceful_shutdown(shutdown_signal()).await
226            } else {
227                server.await
228            }
229        } else {
230            let service = router.into_make_service();
231            let server = axum::serve(listener, service);
232            if config.graceful {
233                server.with_graceful_shutdown(shutdown_signal()).await
234            } else {
235                server.await
236            }
237        }
238        .context("start axum server failed")?;
239
240        Ok("axum schedule finished".to_string())
241    }
242}
243
/// Configures `aide`'s global schema-generation settings.
///
/// Generation errors are reported through `tracing::error!`, and
/// `aide::generate::extract_schemas(false)` is applied. That presumably keeps schemas
/// inline instead of extracting them into shared components; see the aide docs.
#[cfg(feature = "openapi")]
pub fn enable_openapi() {
    aide::generate::on_error(|err| {
        tracing::error!("{err}");
    });
    aide::generate::extract_schemas(false);
}
251
/// Finalises the OpenAPI document and turns the `aide` router into a plain axum router.
///
/// Steps:
/// - nests the routes from [`docs_routes`] under `openapi_conf.doc_prefix`;
/// - uses the [`OpenApi`] component if one was registered, otherwise a default document
///   carrying `openapi_conf.info`;
/// - applies the registered [`OpenApiTransformer`] if there is one;
/// - attaches the finished document as an `Extension<Arc<OpenApi>>` for the JSON handler.
#[cfg(feature = "openapi")]
fn finish_openapi(
    app: &App,
    router: aide::axum::ApiRouter,
    openapi_conf: OpenApiConfig,
) -> axum::Router {
    let docs = docs_routes(&openapi_conf);
    let with_docs = router.nest_api_service(&openapi_conf.doc_prefix, docs);

    let mut document = match app.get_component::<OpenApi>() {
        Some(registered) => registered,
        None => OpenApi {
            info: openapi_conf.info,
            ..Default::default()
        },
    };

    let finished = match app.get_component::<OpenApiTransformer>() {
        Some(transform) => with_docs.finish_api_with(&mut document, transform),
        None => with_docs.finish_api(&mut document),
    };

    finished.layer(Extension(Arc::new(document)))
}
273
/// Builds the router that serves the API documentation.
///
/// `/openapi.json` is always exposed. `/scalar`, `/redoc` and `/swagger` are added when
/// the matching `openapi-*` feature is enabled. The UIs load the spec from
/// `{doc_prefix}/openapi.json`, so this router is meant to be nested under `doc_prefix`.
#[cfg(feature = "openapi")]
pub fn docs_routes(OpenApiConfig { doc_prefix, info }: &OpenApiConfig) -> aide::axum::ApiRouter {
    // Underscore-prefixed because they are unused when no UI feature is enabled.
    let _spec_url: String = format!("{doc_prefix}/openapi.json");
    let _ui_title = &info.title;

    let docs = aide::axum::ApiRouter::new();

    #[cfg(feature = "openapi-scalar")]
    let docs = {
        let scalar = aide::scalar::Scalar::new(&_spec_url).with_title(_ui_title);
        docs.route("/scalar", scalar.axum_route())
    };

    #[cfg(feature = "openapi-redoc")]
    let docs = {
        let redoc = aide::redoc::Redoc::new(&_spec_url).with_title(_ui_title);
        docs.route("/redoc", redoc.axum_route())
    };

    #[cfg(feature = "openapi-swagger")]
    let docs = {
        let swagger = aide::swagger::Swagger::new(&_spec_url).with_title(_ui_title);
        docs.route("/swagger", swagger.axum_route())
    };

    docs.route("/openapi.json", axum::routing::get(serve_docs))
}
304
/// Handler for `/openapi.json`: responds with the finished OpenAPI document as JSON.
///
/// The document comes from the `Extension<Arc<OpenApi>>` layer on the router.
#[cfg(feature = "openapi")]
async fn serve_docs(Extension(api): Extension<Arc<OpenApi>>) -> impl aide::axum::IntoApiResponse {
    let body = axum::Json(api.as_ref());
    axum::response::IntoResponse::into_response(body)
}
309
/// Identity transform for an OpenAPI path item: returns `path_item` unchanged.
///
/// NOTE(review): presumably the default for routes that supply no custom transform;
/// confirm against the route macros.
#[cfg(feature = "openapi")]
pub fn default_transform<'a>(
    path_item: aide::transform::TransformPathItem<'a>,
) -> aide::transform::TransformPathItem<'a> {
    path_item
}
316
317async fn shutdown_signal() {
318    let ctrl_c = async {
319        tokio::signal::ctrl_c()
320            .await
321            .expect("failed to install Ctrl+C handler");
322    };
323
324    #[cfg(unix)]
325    let terminate = async {
326        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
327            .expect("failed to install signal handler")
328            .recv()
329            .await;
330    };
331
332    #[cfg(not(unix))]
333    let terminate = std::future::pending::<()>();
334
335    tokio::select! {
336        _ = ctrl_c => {
337            tracing::info!("Received Ctrl+C signal, waiting for web server shutdown")
338        },
339        _ = terminate => {
340            tracing::info!("Received kill signal, waiting for web server shutdown")
341        },
342    }
343}