spring_web/
lib.rs

1//! [![spring-rs](https://img.shields.io/github/stars/spring-rs/spring-rs)](https://spring-rs.github.io/docs/plugins/spring-web)
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6/// spring-web config
7pub mod config;
8/// spring-web defined error
9pub mod error;
10/// axum extract
11pub mod extractor;
12/// axum route handler
13pub mod handler;
14pub mod middleware;
15#[cfg(feature = "openapi")]
16pub mod openapi;
17
18#[cfg(feature = "socket_io")]
19pub use { socketioxide, rmpv };
20
21pub use axum;
22pub use spring::async_trait;
23use spring::signal;
24/////////////////web-macros/////////////////////
25/// To use these Procedural Macros, you need to add `spring-web` dependency
26pub use spring_macros::middlewares;
27pub use spring_macros::nest;
28
29// route macros
30pub use spring_macros::delete;
31pub use spring_macros::get;
32pub use spring_macros::head;
33pub use spring_macros::options;
34pub use spring_macros::patch;
35pub use spring_macros::post;
36pub use spring_macros::put;
37pub use spring_macros::route;
38pub use spring_macros::routes;
39pub use spring_macros::trace;
40
41/// SocketIO macros
42#[cfg(feature = "socket_io")]
43pub use spring_macros::on_connection;
44#[cfg(feature = "socket_io")]
45pub use spring_macros::on_disconnect;
46#[cfg(feature = "socket_io")]
47pub use spring_macros::on_fallback;
48#[cfg(feature = "socket_io")]
49pub use spring_macros::subscribe_message;
50
51/// OpenAPI macros
52#[cfg(feature = "openapi")]
53pub use spring_macros::api_route;
54#[cfg(feature = "openapi")]
55pub use spring_macros::api_routes;
56#[cfg(feature = "openapi")]
57pub use spring_macros::delete_api;
58#[cfg(feature = "openapi")]
59pub use spring_macros::get_api;
60#[cfg(feature = "openapi")]
61pub use spring_macros::head_api;
62#[cfg(feature = "openapi")]
63pub use spring_macros::options_api;
64#[cfg(feature = "openapi")]
65pub use spring_macros::patch_api;
66#[cfg(feature = "openapi")]
67pub use spring_macros::post_api;
68#[cfg(feature = "openapi")]
69pub use spring_macros::put_api;
70#[cfg(feature = "openapi")]
71pub use spring_macros::trace_api;
72
73/// axum::routing::MethodFilter re-export
74pub use axum::routing::MethodFilter;
75
76/// Router with AppState
77#[cfg(not(feature = "openapi"))]
78pub type Router = axum::Router;
79/// MethodRouter with AppState
80pub use axum::routing::MethodRouter;
81
82#[cfg(feature = "openapi")]
83pub use aide;
84#[cfg(feature = "openapi")]
85pub use aide::openapi::OpenApi;
86#[cfg(feature = "openapi")]
87pub type Router = aide::axum::ApiRouter;
88#[cfg(feature = "openapi")]
89pub use aide::axum::routing::ApiMethodRouter;
90
91#[cfg(feature = "openapi")]
92use aide::transform::TransformOpenApi;
93
94use anyhow::Context;
95use axum::Extension;
96use config::ServerConfig;
97use config::WebConfig;
98use spring::plugin::component::ComponentRef;
99use spring::plugin::ComponentRegistry;
100use spring::plugin::MutableComponentRegistry;
101use spring::{
102    app::{App, AppBuilder},
103    config::ConfigRegistry,
104    error::Result,
105    plugin::Plugin,
106};
107use std::{net::SocketAddr, ops::Deref, sync::Arc};
108
109#[cfg(feature = "socket_io")]
110use config::SocketIOConfig;
111
112#[cfg(feature = "openapi")]
113use crate::config::OpenApiConfig;
114
115/// Routers collection
116#[cfg(feature = "openapi")]
117pub type Routers = Vec<aide::axum::ApiRouter>;
118#[cfg(not(feature = "openapi"))]
119pub type Routers = Vec<axum::Router>;
120
121/// OpenAPI
122#[cfg(feature = "openapi")]
123type OpenApiTransformer = fn(TransformOpenApi) -> TransformOpenApi;
124
125/// Web Configurator
126pub trait WebConfigurator {
127    /// add route to app registry
128    fn add_router(&mut self, router: Router) -> &mut Self;
129
130    /// Initialize OpenAPI Documents
131    #[cfg(feature = "openapi")]
132    fn openapi(&mut self, openapi: OpenApi) -> &mut Self;
133
134    /// Defining OpenAPI Documents
135    #[cfg(feature = "openapi")]
136    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self;
137}
138
139impl WebConfigurator for AppBuilder {
140    fn add_router(&mut self, router: Router) -> &mut Self {
141        if let Some(routers) = self.get_component_ref::<Routers>() {
142            unsafe {
143                let raw_ptr = ComponentRef::into_raw(routers);
144                let routers = &mut *(raw_ptr as *mut Routers);
145                routers.push(router);
146            }
147            self
148        } else {
149            self.add_component(vec![router])
150        }
151    }
152
153    /// Initialize OpenAPI Documents
154    #[cfg(feature = "openapi")]
155    fn openapi(&mut self, openapi: OpenApi) -> &mut Self {
156        self.add_component(openapi)
157    }
158
159    #[cfg(feature = "openapi")]
160    fn api_docs(&mut self, api_docs: OpenApiTransformer) -> &mut Self {
161        self.add_component(api_docs)
162    }
163}
164
165/// State of App
166#[derive(Clone)]
167pub struct AppState {
168    /// App Registry Ref
169    pub app: Arc<App>,
170}
171
172/// Web Plugin Definition
173pub struct WebPlugin;
174
175#[async_trait]
176impl Plugin for WebPlugin {
177    async fn build(&self, app: &mut AppBuilder) {
178        let config = app
179            .get_config::<WebConfig>()
180            .expect("web plugin config load failed");
181
182        #[cfg(feature = "socket_io")]
183        let socketio_config = app.get_config::<SocketIOConfig>().ok();
184
185        // 1. collect router
186        let routers = app.get_component_ref::<Routers>();
187        let mut router: Router = match routers {
188            Some(rs) => {
189                let mut router = Router::new();
190                for r in rs.deref().iter() {
191                    router = router.merge(r.to_owned());
192                }
193                router
194            }
195            None => Router::new(),
196        };
197        if let Some(middlewares) = config.middlewares {
198            router = crate::middleware::apply_middleware(router, middlewares);
199        }
200
201        #[cfg(feature = "socket_io")]
202        if let Some(socketio_config) = socketio_config {
203            router =  enable_socketio(socketio_config, app, router);
204        }
205
206        app.add_component(router);
207
208        let server_conf = config.server;
209        #[cfg(feature = "openapi")]
210        {
211            let openapi_conf = config.openapi;
212            app.add_component(openapi_conf.clone());
213        }
214
215        app.add_scheduler(move |app: Arc<App>| Box::new(Self::schedule(app, server_conf)));
216    }
217}
218
219impl WebPlugin {
220    async fn schedule(app: Arc<App>, config: ServerConfig) -> Result<String> {
221        let router = app.get_expect_component::<Router>();
222
223        // 2. bind tcp listener
224        let addr = SocketAddr::from((config.binding, config.port));
225        let listener = tokio::net::TcpListener::bind(addr)
226            .await
227            .with_context(|| format!("bind tcp listener failed:{addr}"))?;
228        tracing::info!("bind tcp listener: {addr}");
229
230        // 3. openapi
231        #[cfg(feature = "openapi")]
232        let router = {
233            let openapi_conf = app.get_expect_component::<OpenApiConfig>();
234            finish_openapi(&app, router, openapi_conf)
235        };
236
237        // 4. axum server
238        let router = router.layer(Extension(AppState { app }));
239
240        tracing::info!("axum server started");
241        if config.connect_info {
242            // with client connect info
243            let service = router.into_make_service_with_connect_info::<SocketAddr>();
244            let server = axum::serve(listener, service);
245            if config.graceful {
246                server
247                    .with_graceful_shutdown(signal::shutdown_signal())
248                    .await
249            } else {
250                server.await
251            }
252        } else {
253            let service = router.into_make_service();
254            let server = axum::serve(listener, service);
255            if config.graceful {
256                server
257                    .with_graceful_shutdown(signal::shutdown_signal())
258                    .await
259            } else {
260                server.await
261            }
262        }
263        .context("start axum server failed")?;
264
265        Ok("axum schedule finished".to_string())
266    }
267}
268
/// Configure `aide`'s global schema-generation settings.
///
/// Errors reported during schema generation are logged at `error` level, and
/// `aide::generate::extract_schemas(false)` is applied (NOTE(review): presumably
/// this keeps schemas inline instead of moving them to `components` — confirm
/// against the aide docs).
#[cfg(feature = "openapi")]
pub fn enable_openapi() {
    // Surface generation problems in the log rather than dropping them.
    aide::generate::on_error(|error| tracing::error!("{error}"));
    aide::generate::extract_schemas(false);
}
276
/// Attach a Socket.IO layer to `router` and register the `SocketIo` handle.
///
/// One namespace, `socketio_config.default_namespace`, is created. Every socket
/// that connects to it is logged and passed to `crate::handler::auto_socketio_setup`.
/// The `socketioxide::SocketIo` handle is added to `app` as a component.
///
/// Returns `router` wrapped in the Socket.IO layer.
#[cfg(feature = "socket_io")]
pub fn enable_socketio(
    socketio_config: SocketIOConfig,
    app: &mut AppBuilder,
    router: Router,
) -> Router {
    tracing::info!(
        "Configuring SocketIO with namespace: {}",
        socketio_config.default_namespace
    );

    let (layer, io) = socketioxide::SocketIo::builder().build_layer();

    // The config is consumed here, so move the namespace out instead of cloning
    // it; only the connect handler needs its own copy (for logging).
    let ns_path = socketio_config.default_namespace;
    let ns_path_for_closure = ns_path.clone();
    io.ns(ns_path, move |socket: socketioxide::extract::SocketRef| {
        tracing::info!(
            socket_id = ?socket.id,
            "New socket connected to namespace: {}",
            ns_path_for_closure
        );

        crate::handler::auto_socketio_setup(&socket);
    });

    app.add_component(io);
    router.layer(layer)
}
297
/// Finish the API router into a plain axum router that carries its OpenAPI document.
///
/// Mounts the routes from `docs_routes` under `openapi_conf.doc_prefix`, then
/// builds the document:
/// - the base document is a registered [`OpenApi`] component if there is one,
///   otherwise a default document carrying `openapi_conf.info`;
/// - a registered `OpenApiTransformer` is applied when present.
///
/// The finished document is installed as an `Extension<Arc<OpenApi>>` layer.
#[cfg(feature = "openapi")]
fn finish_openapi(
    app: &App,
    router: aide::axum::ApiRouter,
    openapi_conf: OpenApiConfig,
) -> axum::Router {
    let docs = docs_routes(&openapi_conf);
    let with_docs = router.nest_api_service(&openapi_conf.doc_prefix, docs);

    let mut api = match app.get_component::<OpenApi>() {
        Some(custom) => custom,
        None => OpenApi {
            info: openapi_conf.info,
            ..Default::default()
        },
    };

    let finished = match app.get_component::<OpenApiTransformer>() {
        Some(transform) => with_docs.finish_api_with(&mut api, transform),
        None => with_docs.finish_api(&mut api),
    };

    finished.layer(Extension(Arc::new(api)))
}
319
/// Build the documentation routes (relative paths; the caller chooses the mount point).
///
/// `/openapi.json` always serves the generated document. UI pages are added
/// depending on enabled features: `/scalar` (`openapi-scalar`), `/redoc`
/// (`openapi-redoc`) and `/swagger` (`openapi-swagger`). Each UI loads the
/// spec from `{doc_prefix}/openapi.json` and is titled with `info.title`.
#[cfg(feature = "openapi")]
pub fn docs_routes(OpenApiConfig { doc_prefix, info }: &OpenApiConfig) -> aide::axum::ApiRouter {
    // Underscore-prefixed so builds without any UI feature do not warn about
    // unused bindings.
    let _spec_url = &format!("{doc_prefix}/openapi.json");
    let _ui_title = &info.title;
    let routes = aide::axum::ApiRouter::new();

    #[cfg(feature = "openapi-scalar")]
    let routes = routes.route(
        "/scalar",
        aide::scalar::Scalar::new(_spec_url)
            .with_title(_ui_title)
            .axum_route(),
    );

    #[cfg(feature = "openapi-redoc")]
    let routes = routes.route(
        "/redoc",
        aide::redoc::Redoc::new(_spec_url)
            .with_title(_ui_title)
            .axum_route(),
    );

    #[cfg(feature = "openapi-swagger")]
    let routes = routes.route(
        "/swagger",
        aide::swagger::Swagger::new(_spec_url)
            .with_title(_ui_title)
            .axum_route(),
    );

    routes.route("/openapi.json", axum::routing::get(serve_docs))
}
350
/// Respond with the finished OpenAPI document serialized as JSON.
///
/// The document is read from the `Extension<Arc<OpenApi>>` request extension.
#[cfg(feature = "openapi")]
async fn serve_docs(Extension(api): Extension<Arc<OpenApi>>) -> impl aide::axum::IntoApiResponse {
    use axum::response::IntoResponse;

    let document: &OpenApi = &api;
    axum::Json(document).into_response()
}
355
/// Identity path-item transform: returns `path_item` unchanged.
///
/// NOTE(review): presumably intended as a no-op default for APIs that take a
/// `TransformPathItem` callback — confirm against callers.
#[cfg(feature = "openapi")]
pub fn default_transform(
    path_item: aide::transform::TransformPathItem<'_>,
) -> aide::transform::TransformPathItem<'_> {
    path_item
}