spring_web/
lib.rs

1//! [![spring-rs](https://img.shields.io/github/stars/spring-rs/spring-rs)](https://spring-rs.github.io/docs/plugins/spring-web)
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6/// spring-web config
7pub mod config;
8/// spring-web defined error
9pub mod error;
10/// axum extract
11pub mod extractor;
12/// axum route handler
13pub mod handler;
14pub mod middleware;
15
16pub use axum;
17pub use spring::async_trait;
18/////////////////web-macros/////////////////////
19/// To use these Procedural Macros, you need to add `spring-web` dependency
20pub use spring_macros::delete;
21pub use spring_macros::get;
22pub use spring_macros::head;
23pub use spring_macros::middlewares;
24pub use spring_macros::nest;
25pub use spring_macros::options;
26pub use spring_macros::patch;
27pub use spring_macros::post;
28pub use spring_macros::put;
29pub use spring_macros::route;
30pub use spring_macros::routes;
31pub use spring_macros::trace;
32
33/// axum::routing::MethodFilter re-export
34pub use axum::routing::MethodFilter;
35/// MethodRouter with AppState
36pub use axum::routing::MethodRouter;
37/// Router with AppState
38pub use axum::Router;
39
40use anyhow::Context;
41use axum::Extension;
42use config::ServerConfig;
43use config::WebConfig;
44use spring::plugin::component::ComponentRef;
45use spring::plugin::ComponentRegistry;
46use spring::plugin::MutableComponentRegistry;
47use spring::{
48    app::{App, AppBuilder},
49    config::ConfigRegistry,
50    error::Result,
51    plugin::Plugin,
52};
53use std::{net::SocketAddr, ops::Deref, sync::Arc};
54
55/// Routers collection
56pub type Routers = Vec<Router>;
57
58/// Web Configurator
59pub trait WebConfigurator {
60    /// add route to app registry
61    fn add_router(&mut self, router: Router) -> &mut Self;
62}
63
64impl WebConfigurator for AppBuilder {
65    fn add_router(&mut self, router: Router) -> &mut Self {
66        if let Some(routers) = self.get_component_ref::<Routers>() {
67            unsafe {
68                let raw_ptr = ComponentRef::into_raw(routers);
69                let routers = &mut *(raw_ptr as *mut Routers);
70                routers.push(router);
71            }
72            self
73        } else {
74            self.add_component(vec![router])
75        }
76    }
77}
78
79/// State of App
80#[derive(Clone)]
81pub struct AppState {
82    /// App Registry Ref
83    pub app: Arc<App>,
84}
85
/// The spring-web plugin. When built, it merges all registered [`Routers`], applies the
/// configured middlewares, and schedules an axum server.
#[derive(Debug, Clone, Copy, Default)]
pub struct WebPlugin;
88
89#[async_trait]
90impl Plugin for WebPlugin {
91    async fn build(&self, app: &mut AppBuilder) {
92        let config = app
93            .get_config::<WebConfig>()
94            .expect("web plugin config load failed");
95
96        // 1. collect router
97        let routers = app.get_component_ref::<Routers>();
98        let mut router: Router = match routers {
99            Some(rs) => {
100                let mut router = Router::new();
101                for r in rs.deref().iter() {
102                    router = router.merge(r.to_owned());
103                }
104                router
105            }
106            None => Router::new(),
107        };
108        if let Some(middlewares) = config.middlewares {
109            router = crate::middleware::apply_middleware(router, middlewares);
110        }
111
112        let server_conf = config.server;
113
114        app.add_scheduler(move |app: Arc<App>| Box::new(Self::schedule(router, app, server_conf)));
115    }
116}
117
118impl WebPlugin {
119    async fn schedule(router: Router, app: Arc<App>, config: ServerConfig) -> Result<String> {
120        // 2. bind tcp listener
121        let addr = SocketAddr::from((config.binding, config.port));
122        let listener = tokio::net::TcpListener::bind(addr)
123            .await
124            .with_context(|| format!("bind tcp listener failed:{}", addr))?;
125        tracing::info!("bind tcp listener: {}", addr);
126
127        // 3. axum server
128        let router = router.layer(Extension(AppState { app }));
129
130        tracing::info!("axum server started");
131        if config.connect_info {
132            // with client connect info
133            let service = router.into_make_service_with_connect_info::<SocketAddr>();
134            let server = axum::serve(listener, service);
135            if config.graceful {
136                server.with_graceful_shutdown(shutdown_signal()).await
137            } else {
138                server.await
139            }
140        } else {
141            let service = router.into_make_service();
142            let server = axum::serve(listener, service);
143            if config.graceful {
144                server.with_graceful_shutdown(shutdown_signal()).await
145            } else {
146                server.await
147            }
148        }
149        .context("start axum server failed")?;
150
151        Ok("axum schedule finished".to_string())
152    }
153}
154
155async fn shutdown_signal() {
156    let ctrl_c = async {
157        tokio::signal::ctrl_c()
158            .await
159            .expect("failed to install Ctrl+C handler");
160    };
161
162    #[cfg(unix)]
163    let terminate = async {
164        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
165            .expect("failed to install signal handler")
166            .recv()
167            .await;
168    };
169
170    #[cfg(not(unix))]
171    let terminate = std::future::pending::<()>();
172
173    tokio::select! {
174        _ = ctrl_c => {
175            tracing::info!("Received Ctrl+C signal, waiting for web server shutdown")
176        },
177        _ = terminate => {
178            tracing::info!("Received kill signal, waiting for web server shutdown")
179        },
180    }
181}