spring_web/
lib.rs

1//! [![spring-rs](https://img.shields.io/github/stars/spring-rs/spring-rs)](https://spring-rs.github.io/docs/plugins/spring-web)
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6/// spring-web config
7pub mod config;
8/// spring-web defined error
9pub mod error;
10/// axum extract
11pub mod extractor;
12/// axum route handler
13pub mod handler;
14pub mod middleware;
15#[cfg(feature = "openapi")]
16pub mod openapi;
17/// RFC 7807 Problem Details for HTTP APIs
18pub mod problem_details;
19
20pub trait HttpStatusCode {
21    fn status_code(&self) -> axum::http::StatusCode;
22}
23
24pub use spring_macros::ProblemDetails;
25
26#[cfg(feature = "socket_io")]
27pub use { socketioxide, rmpv };
28
29pub use axum;
30pub use spring::async_trait;
31use spring::signal;
32/////////////////web-macros/////////////////////
33/// To use these Procedural Macros, you need to add `spring-web` dependency
34pub use spring_macros::middlewares;
35pub use spring_macros::nest;
36
37// route macros
38pub use spring_macros::delete;
39pub use spring_macros::get;
40pub use spring_macros::head;
41pub use spring_macros::options;
42pub use spring_macros::patch;
43pub use spring_macros::post;
44pub use spring_macros::put;
45pub use spring_macros::route;
46pub use spring_macros::routes;
47pub use spring_macros::trace;
48
49/// SocketIO macros
50#[cfg(feature = "socket_io")]
51pub use spring_macros::on_connection;
52#[cfg(feature = "socket_io")]
53pub use spring_macros::on_disconnect;
54#[cfg(feature = "socket_io")]
55pub use spring_macros::on_fallback;
56#[cfg(feature = "socket_io")]
57pub use spring_macros::subscribe_message;
58
59/// OpenAPI macros
60#[cfg(feature = "openapi")]
61pub use spring_macros::api_route;
62#[cfg(feature = "openapi")]
63pub use spring_macros::api_routes;
64#[cfg(feature = "openapi")]
65pub use spring_macros::delete_api;
66#[cfg(feature = "openapi")]
67pub use spring_macros::get_api;
68#[cfg(feature = "openapi")]
69pub use spring_macros::head_api;
70#[cfg(feature = "openapi")]
71pub use spring_macros::options_api;
72#[cfg(feature = "openapi")]
73pub use spring_macros::patch_api;
74#[cfg(feature = "openapi")]
75pub use spring_macros::post_api;
76#[cfg(feature = "openapi")]
77pub use spring_macros::put_api;
78#[cfg(feature = "openapi")]
79pub use spring_macros::trace_api;
80
81/// axum::routing::MethodFilter re-export
82pub use axum::routing::MethodFilter;
83
84/// Router with AppState
85#[cfg(not(feature = "openapi"))]
86pub type Router = axum::Router;
87/// MethodRouter with AppState
88pub use axum::routing::MethodRouter;
89
90#[cfg(feature = "openapi")]
91pub use aide;
92#[cfg(feature = "openapi")]
93pub use aide::openapi::OpenApi;
94#[cfg(feature = "openapi")]
95pub type Router = aide::axum::ApiRouter;
96#[cfg(feature = "openapi")]
97pub use aide::axum::routing::ApiMethodRouter;
98
99#[cfg(feature = "openapi")]
100use aide::transform::TransformOpenApi;
101
102use anyhow::Context;
103use axum::Extension;
104use config::ServerConfig;
105use config::WebConfig;
106use spring::plugin::component::ComponentRef;
107use spring::plugin::ComponentRegistry;
108use spring::plugin::MutableComponentRegistry;
109use spring::{
110    app::{App, AppBuilder},
111    config::ConfigRegistry,
112    error::Result,
113    plugin::Plugin,
114};
115use std::{net::SocketAddr, ops::Deref, sync::Arc};
116
117#[cfg(feature = "socket_io")]
118use config::SocketIOConfig;
119
120#[cfg(feature = "openapi")]
121use crate::config::OpenApiConfig;
122
123/// Routers collection
124#[cfg(feature = "openapi")]
125pub type Routers = Vec<aide::axum::ApiRouter>;
126#[cfg(not(feature = "openapi"))]
127pub type Routers = Vec<axum::Router>;
128
129/// Router layer function type
130///
131/// Used to add layers (middleware) to the router before the server starts.
132/// This enables plugins to dynamically register middleware layers.
133///
134/// # Example
135///
136/// ```rust,ignore
137/// use spring_web::{Router, LayerConfigurator};
138///
139/// // In your plugin's build method:
140/// app.add_router_layer(|router: Router| {
141///     router.layer(MyMiddlewareLayer::new())
142/// });
143/// ```
144pub type RouterLayer = Arc<dyn Fn(Router) -> Router + Send + Sync>;
145
146/// Collection of router layers
147pub type RouterLayers = Vec<RouterLayer>;
148
149/// Trait for adding layers to the web router
150pub trait LayerConfigurator {
151    /// Add a layer function that will be applied to the router before the server starts.
152    ///
153    /// Layers are applied in the order they are added.
154    ///
155    /// # Example
156    ///
157    /// ```rust,ignore
158    /// use spring_web::LayerConfigurator;
159    ///
160    /// app.add_router_layer(|router| {
161    ///     router.layer(MyAuthLayer::new(state))
162    /// });
163    /// ```
164    fn add_router_layer<F>(&mut self, layer: F) -> &mut Self
165    where
166        F: Fn(Router) -> Router + Send + Sync + 'static;
167}
168
169impl LayerConfigurator for AppBuilder {
170    fn add_router_layer<F>(&mut self, layer: F) -> &mut Self
171    where
172        F: Fn(Router) -> Router + Send + Sync + 'static,
173    {
174        if let Some(layers) = self.get_component_ref::<RouterLayers>() {
175            unsafe {
176                let raw_ptr = ComponentRef::into_raw(layers);
177                let layers = &mut *(raw_ptr as *mut RouterLayers);
178                layers.push(Arc::new(layer));
179            }
180            self
181        } else {
182            let layers: RouterLayers = vec![Arc::new(layer)];
183            self.add_component(layers)
184        }
185    }
186}
187
/// Customization hook for the generated OpenAPI document.
///
/// Registered through [`WebConfigurator::api_docs`]. When present, it is passed to
/// `finish_api_with` as the server starts.
#[cfg(feature = "openapi")]
type OpenApiTransformer = fn(TransformOpenApi) -> TransformOpenApi;
191
192/// Web Configurator
193pub trait WebConfigurator {
194    /// add route to app registry
195    fn add_router(&mut self, router: Router) -> &mut Self;
196
197    /// Initialize OpenAPI Documents
198    #[cfg(feature = "openapi")]
199    fn openapi(&mut self, openapi: OpenApi) -> &mut Self;
200
201    /// Defining OpenAPI Documents
202    #[cfg(feature = "openapi")]
203    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self;
204}
205
206impl WebConfigurator for AppBuilder {
207    fn add_router(&mut self, router: Router) -> &mut Self {
208        if let Some(routers) = self.get_component_ref::<Routers>() {
209            unsafe {
210                let raw_ptr = ComponentRef::into_raw(routers);
211                let routers = &mut *(raw_ptr as *mut Routers);
212                routers.push(router);
213            }
214            self
215        } else {
216            self.add_component(vec![router])
217        }
218    }
219
220    /// Initialize OpenAPI Documents
221    #[cfg(feature = "openapi")]
222    fn openapi(&mut self, openapi: OpenApi) -> &mut Self {
223        self.add_component(openapi)
224    }
225
226    #[cfg(feature = "openapi")]
227    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self {
228        self.add_component(api_docs)
229    }
230}
231
232/// State of App
233#[derive(Clone)]
234pub struct AppState {
235    /// App Registry Ref
236    pub app: Arc<App>,
237}
238
/// The spring-web plugin: assembles the registered routers while the app builds, then
/// serves them with axum from a scheduler.
pub struct WebPlugin;
241
242#[async_trait]
243impl Plugin for WebPlugin {
244    async fn build(&self, app: &mut AppBuilder) {
245        let config = app
246            .get_config::<WebConfig>()
247            .expect("web plugin config load failed");
248
249        #[cfg(feature = "socket_io")]
250        let socketio_config = app.get_config::<SocketIOConfig>().ok();
251
252        // 1. collect router
253        let routers = app.get_component_ref::<Routers>();
254        let mut router: Router = match routers {
255            Some(rs) => {
256                let mut router = Router::new();
257                for r in rs.deref().iter() {
258                    router = router.merge(r.to_owned());
259                }
260                router
261            }
262            None => Router::new(),
263        };
264        if let Some(middlewares) = config.middlewares {
265            router = crate::middleware::apply_middleware(router, middlewares);
266        }
267
268        #[cfg(feature = "socket_io")]
269        if let Some(socketio_config) = socketio_config {
270            router =  enable_socketio(socketio_config, app, router);
271        }
272
273        app.add_component(router);
274
275        let server_conf = config.server;
276        #[cfg(feature = "openapi")]
277        {
278            let openapi_conf = config.openapi;
279            app.add_component(openapi_conf.clone());
280        }
281
282        app.add_scheduler(move |app: Arc<App>| Box::new(Self::schedule(app, server_conf)));
283    }
284}
285
286impl WebPlugin {
287    async fn schedule(app: Arc<App>, config: ServerConfig) -> Result<String> {
288        let mut router = app.get_expect_component::<Router>();
289
290        // Apply custom router layers registered by plugins
291        // This is done in schedule() after all plugins have built,
292        // ensuring plugins that depend on other plugins can still register layers
293        if let Some(layers) = app.get_component_ref::<RouterLayers>() {
294            for layer_fn in layers.deref().iter() {
295                router = layer_fn(router);
296            }
297        }
298
299        // 2. bind tcp listener
300        let addr = SocketAddr::from((config.binding, config.port));
301        let listener = tokio::net::TcpListener::bind(addr)
302            .await
303            .with_context(|| format!("bind tcp listener failed:{addr}"))?;
304        tracing::info!("bind tcp listener: {addr}");
305
306        // 3. openapi
307        #[cfg(feature = "openapi")]
308        let router = {
309            let openapi_conf = app.get_expect_component::<OpenApiConfig>();
310            finish_openapi(&app, router, openapi_conf)
311        };
312
313        // 4. axum server
314        let mut router = router.layer(Extension(AppState { app }));
315
316        if !config.global_prefix.is_empty() {
317            router = axum::Router::new().nest(&config.global_prefix, router)
318        };
319
320
321        tracing::info!("axum server started");
322        if config.connect_info {
323            // with client connect info
324            let service = router.into_make_service_with_connect_info::<SocketAddr>();
325            let server = axum::serve(listener, service);
326            if config.graceful {
327                server
328                    .with_graceful_shutdown(signal::shutdown_signal("axum web server"))
329                    .await
330            } else {
331                server.await
332            }
333        } else {
334            let service = router.into_make_service();
335            let server = axum::serve(listener, service);
336            if config.graceful {
337                server
338                    .with_graceful_shutdown(signal::shutdown_signal("axum web server"))
339                    .await
340            } else {
341                server.await
342            }
343        }
344        .context("start axum server failed")?;
345
346        Ok("axum schedule finished".to_string())
347    }
348}
349
/// Configures aide's global OpenAPI generation settings.
///
/// Generation errors reported by aide are logged with `tracing::error!`, and aide's
/// `extract_schemas` option is switched off.
#[cfg(feature = "openapi")]
pub fn enable_openapi() {
    aide::generate::on_error(|err| {
        tracing::error!("{err}");
    });
    aide::generate::extract_schemas(false);
}
357
/// Installs Socket.IO support on `router`.
///
/// Builds the socketioxide layer and registers a handler for the configured default
/// namespace. Each connecting socket is logged and handed to
/// `crate::handler::auto_socketio_setup`. The `SocketIo` handle is stored as a component
/// on `app`, and `router` is returned wrapped with the Socket.IO layer.
#[cfg(feature = "socket_io")]
pub fn enable_socketio(socketio_config: SocketIOConfig, app: &mut AppBuilder, router: Router) -> Router {
    tracing::info!("Configuring SocketIO with namespace: {}", socketio_config.default_namespace);

    let (socketio_layer, io) = socketioxide::SocketIo::builder().build_layer();

    // One copy is handed to `ns`, the other is moved into the handler for logging.
    let namespace = socketio_config.default_namespace.clone();
    let logged_namespace = namespace.clone();
    io.ns(namespace, move |socket: socketioxide::extract::SocketRef| {
        use spring::tracing::info;

        info!(socket_id = ?socket.id, "New socket connected to namespace: {}", logged_namespace);

        crate::handler::auto_socketio_setup(&socket);
    });

    app.add_component(io);
    router.layer(socketio_layer)
}
378
/// Finalizes the OpenAPI document and turns the `ApiRouter` into a plain `axum::Router`.
///
/// Nests the documentation routes under `openapi_conf.doc_prefix`. Uses the user-supplied
/// [`OpenApi`] component if present, otherwise a document built from `openapi_conf.info`.
/// Applies the registered `OpenApiTransformer` (if any), and exposes the finished document
/// to handlers as an `Extension<Arc<OpenApi>>`.
#[cfg(feature = "openapi")]
fn finish_openapi(
    app: &App,
    router: aide::axum::ApiRouter,
    openapi_conf: OpenApiConfig,
) -> axum::Router {
    let docs = docs_routes(&openapi_conf);
    let router = router.nest_api_service(&openapi_conf.doc_prefix, docs);

    let mut api = app.get_component::<OpenApi>().unwrap_or_else(|| OpenApi {
        info: openapi_conf.info,
        ..Default::default()
    });

    let finished = match app.get_component::<OpenApiTransformer>() {
        Some(transform) => router.finish_api_with(&mut api, transform),
        None => router.finish_api(&mut api),
    };

    finished.layer(Extension(Arc::new(api)))
}
400
/// Builds the router that serves the OpenAPI documentation. `finish_openapi` nests it
/// under `doc_prefix`.
///
/// The JSON document is always served at `/openapi.json`. Depending on enabled features,
/// Scalar (`/scalar`), Redoc (`/redoc`) and Swagger UI (`/swagger`) pages are added too;
/// they load the document from `{doc_prefix}/openapi.json`.
///
/// NOTE(review): that spec URL is absolute and does not include `server.global_prefix`.
/// When a global prefix is configured, the UI pages appear to request the wrong path.
/// TODO confirm.
#[cfg(feature = "openapi")]
pub fn docs_routes(OpenApiConfig { doc_prefix, info }: &OpenApiConfig) -> aide::axum::ApiRouter {
    // Underscore-prefixed because they are unused when no UI feature is enabled.
    let _spec_url = &format!("{doc_prefix}/openapi.json");
    let _page_title = &info.title;
    let docs = aide::axum::ApiRouter::new();

    #[cfg(feature = "openapi-scalar")]
    let docs = docs.route(
        "/scalar",
        aide::scalar::Scalar::new(_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );
    #[cfg(feature = "openapi-redoc")]
    let docs = docs.route(
        "/redoc",
        aide::redoc::Redoc::new(_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );
    #[cfg(feature = "openapi-swagger")]
    let docs = docs.route(
        "/swagger",
        aide::swagger::Swagger::new(_spec_url)
            .with_title(_page_title)
            .axum_route(),
    );

    docs.route("/openapi.json", axum::routing::get(serve_docs))
}
431
/// Handler that responds with the finished OpenAPI document serialized as JSON.
///
/// The document is the `Extension<Arc<OpenApi>>` attached by `finish_openapi`.
#[cfg(feature = "openapi")]
async fn serve_docs(Extension(api): Extension<Arc<OpenApi>>) -> impl aide::axum::IntoApiResponse {
    let document: &OpenApi = &api;
    axum::response::IntoResponse::into_response(axum::Json(document))
}
436
/// Identity path-item transform: returns `path_item` unchanged.
///
/// NOTE(review): presumably used as the no-op transform by the OpenAPI route macros;
/// verify against `spring_macros`.
#[cfg(feature = "openapi")]
pub fn default_transform(
    path_item: aide::transform::TransformPathItem<'_>,
) -> aide::transform::TransformPathItem<'_> {
    path_item
}