spring_web/
lib.rs

1//! [![spring-rs](https://img.shields.io/github/stars/spring-rs/spring-rs)](https://spring-rs.github.io/docs/plugins/spring-web)
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://spring-rs.github.io/favicon.ico")]
4#![doc(html_logo_url = "https://spring-rs.github.io/logo.svg")]
5
6/// spring-web config
7pub mod config;
8/// spring-web defined error
9pub mod error;
10/// axum extract
11pub mod extractor;
12/// axum route handler
13pub mod handler;
14pub mod middleware;
15
16pub use axum;
17pub use spring::async_trait;
18/////////////////web-macros/////////////////////
/// To use these procedural macros, add the `spring-web` dependency.
20pub use spring_macros::delete;
21pub use spring_macros::get;
22pub use spring_macros::head;
23pub use spring_macros::nest;
24pub use spring_macros::options;
25pub use spring_macros::patch;
26pub use spring_macros::post;
27pub use spring_macros::put;
28pub use spring_macros::route;
29pub use spring_macros::routes;
30pub use spring_macros::trace;
31
32/// axum::routing::MethodFilter re-export
33pub use axum::routing::MethodFilter;
34/// MethodRouter with AppState
35pub use axum::routing::MethodRouter;
36/// Router with AppState
37pub use axum::Router;
38
39use anyhow::Context;
40use axum::Extension;
41use config::ServerConfig;
42use config::WebConfig;
43use spring::plugin::component::ComponentRef;
44use spring::plugin::ComponentRegistry;
45use spring::plugin::MutableComponentRegistry;
46use spring::{
47    app::{App, AppBuilder},
48    config::ConfigRegistry,
49    error::Result,
50    plugin::Plugin,
51};
52use std::{net::SocketAddr, ops::Deref, sync::Arc};
53
54/// Routers collection
55pub type Routers = Vec<Router>;
56
57/// Web Configurator
58pub trait WebConfigurator {
59    /// add route to app registry
60    fn add_router(&mut self, router: Router) -> &mut Self;
61}
62
63impl WebConfigurator for AppBuilder {
64    fn add_router(&mut self, router: Router) -> &mut Self {
65        if let Some(routers) = self.get_component_ref::<Routers>() {
66            unsafe {
67                let raw_ptr = ComponentRef::into_raw(routers);
68                let routers = &mut *(raw_ptr as *mut Routers);
69                routers.push(router);
70            }
71            self
72        } else {
73            self.add_component(vec![router])
74        }
75    }
76}
77
78/// State of App
79#[derive(Clone)]
80pub struct AppState {
81    /// App Registry Ref
82    pub app: Arc<App>,
83}
84
/// Plugin that runs an axum HTTP server as an application scheduler.
///
/// It collects the registered [`Routers`] and applies any configured middleware.
pub struct WebPlugin;
87
88#[async_trait]
89impl Plugin for WebPlugin {
90    async fn build(&self, app: &mut AppBuilder) {
91        let config = app
92            .get_config::<WebConfig>()
93            .expect("web plugin config load failed");
94
95        // 1. collect router
96        let routers = app.get_component_ref::<Routers>();
97        let mut router: Router = match routers {
98            Some(rs) => {
99                let mut router = Router::new();
100                for r in rs.deref().iter() {
101                    router = router.merge(r.to_owned());
102                }
103                router
104            }
105            None => Router::new(),
106        };
107        if let Some(middlewares) = config.middlewares {
108            router = crate::middleware::apply_middleware(router, middlewares);
109        }
110
111        let server_conf = config.server;
112
113        app.add_scheduler(move |app: Arc<App>| Box::new(Self::schedule(router, app, server_conf)));
114    }
115}
116
117impl WebPlugin {
118    async fn schedule(router: Router, app: Arc<App>, config: ServerConfig) -> Result<String> {
119        // 2. bind tcp listener
120        let addr = SocketAddr::from((config.binding, config.port));
121        let listener = tokio::net::TcpListener::bind(addr)
122            .await
123            .with_context(|| format!("bind tcp listener failed:{}", addr))?;
124        tracing::info!("bind tcp listener: {}", addr);
125
126        // 3. axum server
127        let router = router.layer(Extension(AppState { app }));
128
129        tracing::info!("axum server started");
130        if config.connect_info {
131            // with client connect info
132            let service = router.into_make_service_with_connect_info::<SocketAddr>();
133            let server = axum::serve(listener, service);
134            if config.graceful {
135                server.with_graceful_shutdown(shutdown_signal()).await
136            } else {
137                server.await
138            }
139        } else {
140            let service = router.into_make_service();
141            let server = axum::serve(listener, service);
142            if config.graceful {
143                server.with_graceful_shutdown(shutdown_signal()).await
144            } else {
145                server.await
146            }
147        }
148        .context("start axum server failed")?;
149
150        Ok("axum schedule finished".to_string())
151    }
152}
153
154async fn shutdown_signal() {
155    let ctrl_c = async {
156        tokio::signal::ctrl_c()
157            .await
158            .expect("failed to install Ctrl+C handler");
159    };
160
161    #[cfg(unix)]
162    let terminate = async {
163        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
164            .expect("failed to install signal handler")
165            .recv()
166            .await;
167    };
168
169    #[cfg(not(unix))]
170    let terminate = std::future::pending::<()>();
171
172    tokio::select! {
173        _ = ctrl_c => {
174            tracing::info!("Received Ctrl+C signal, waiting for web server shutdown")
175        },
176        _ = terminate => {
177            tracing::info!("Received kill signal, waiting for web server shutdown")
178        },
179    }
180}