spring_axum/
lib.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    fs,
5    net::SocketAddr,
6    sync::{Arc, RwLock},
7};
8
9use anyhow::Result;
10// use of axum items will rely on `pub use` exports below
11use tower_http::{cors::CorsLayer, trace::TraceLayer};
12use tower_http::add_extension::AddExtensionLayer;
13use tower::{Layer, Service};
14use axum::body::Body;
15use axum::http::{Request, Response};
16use futures::future::BoxFuture;
17use axum::response::IntoResponse;
18use serde::de::DeserializeOwned;
19use tracing_subscriber::{fmt, EnvFilter};
20use inventory;
21use std::any::TypeId as StdTypeId;
22use once_cell::sync::OnceCell;
23use std::time::{Duration, Instant};
24use axum::response::Html;
25use regex::Regex;
26#[cfg(feature = "devtools")]
27use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
28#[cfg(feature = "devtools")]
29use tokio::sync::broadcast;
30#[cfg(feature = "devtools")]
31use tokio_stream::wrappers::BroadcastStream;
32#[cfg(feature = "devtools")]
33use tokio_stream::StreamExt;
34#[cfg(feature = "devtools")]
35use notify::{Config as NotifyConfig, RecommendedWatcher, RecursiveMode};
36#[cfg(feature = "devtools")]
37use notify::Watcher;
38#[cfg(feature = "swagger")]
39use utoipa::openapi::{OpenApiBuilder, InfoBuilder, PathsBuilder};
/// Process-wide, write-once cell for an OpenAPI document.
/// Only compiled when the `swagger` feature is enabled.
#[cfg(feature = "swagger")]
static OPENAPI: OnceCell<utoipa::openapi::OpenApi> = OnceCell::new();
/// Process-wide, write-once cell for a JSON value.
/// NOTE(review): presumably the serialized form of `OPENAPI` — confirm against the swagger routes.
#[cfg(feature = "swagger")]
static OPENAPI_JSON: OnceCell<serde_json::Value> = OnceCell::new();
44
45// Re-export commonly used axum items for users of the framework
46pub use axum::Router;
47pub use axum::routing::{delete, get, post, put};
48pub use axum::extract::Path;
49pub use axum::extract::Json;
50pub use axum::extract::rejection::JsonRejection;
51pub use axum::http::StatusCode;
52pub use axum::body::Body as AxumBody;
53pub use axum::http::{Request as AxumRequest, Response as AxumResponse};
54// Re-export macros
55pub use spring_axum_macros::*;
56
57// ---------------- Convenience Macros for MyBatis-like integration -----------------
/// Initializes the global MyBatis XML registry from a directory.
///
/// Usage: `spring_axum::mybatis_init!("resources/mybatis");`
///
/// # Panics
/// Logs the error through `tracing::error!` and then panics if
/// `spring_axum_mybatis::init_global_from_dir_once` returns an error.
#[macro_export]
macro_rules! mybatis_init {
    ($dir:expr) => {{
        match ::spring_axum_mybatis::init_global_from_dir_once($dir) {
            Ok(_) => {}
            Err(init_err) => {
                ::tracing::error!(error = %init_err, "failed to init mybatis from dir");
                panic!("MyBatis init failed: {}", init_err);
            }
        }
    }};
}
69
/// Returns the global MyBatis registry.
///
/// # Panics
/// Panics with an explanatory message when the registry has not been
/// initialized yet (see `mybatis_init!`).
#[macro_export]
macro_rules! mybatis_registry {
    () => {{
        let registry = ::spring_axum_mybatis::global_registry();
        registry.expect("MyBatis global registry not initialized; call mybatis_init!(...) early")
    }};
}
78
/// Prepares the statement `$stmt_id` with JSON parameters and runs it on `$executor`.
///
/// Usage: `spring_axum::mybatis_exec!(exec, "UserMapper.saveProfile", json!({...}));`
///
/// The expansion uses `?`, so it must be invoked inside a function whose error
/// type can be built from `AppError`. Both failure modes (unknown statement,
/// executor error) are reported as `AppError::Internal`.
#[macro_export]
macro_rules! mybatis_exec {
    ($executor:expr, $stmt_id:expr, $params_json:expr) => {{
        let registry = $crate::mybatis_registry!();
        let prepared = registry.prepare_json($stmt_id, &$params_json);
        let (statement_sql, bound_params) =
            prepared.ok_or_else(|| $crate::AppError::Internal("statement not found".into()))?;
        let outcome = $executor.execute(&statement_sql, &bound_params);
        outcome.map_err(|exec_err| $crate::AppError::Internal(format!("execute error: {}", exec_err)))?;
    }};
}
93
/// Defines a mapper method bound to the statement `<namespace>.<method_name>`.
///
/// The generated method is `pub fn <name>(&self, <args>) -> AppResult<()>`. It
/// collects the arguments into a JSON object keyed by argument name and hands
/// it to `mybatis_exec!`.
///
/// Example:
/// ```ignore
/// impl UserMapper {
///     spring_axum::sql_method!(save_profile(username: &str, email: &str); namespace = "UserMapper");
/// }
/// ```
///
/// NOTE(review): the generated body always executes through
/// `spring_axum_mybatis::NoopExecutor::default()`; confirm this is intended.
///
/// Framework items are referenced through `$crate` (as `mybatis_exec!` already
/// does) so the expansion keeps working if the crate is renamed by a dependent
/// or the macro is used from inside this crate.
#[macro_export]
macro_rules! sql_method {
    ($fn_name:ident ( $($arg:ident : $typ:ty),* $(,)? ) ; namespace = $ns:literal) => {
        pub fn $fn_name(&self, $($arg: $typ),*) -> $crate::AppResult<()> {
            let params = ::serde_json::json!({ $(stringify!($arg): $arg),* });
            let exec = ::spring_axum_mybatis::NoopExecutor::default();
            $crate::mybatis_exec!(exec, concat!($ns, ".", stringify!($fn_name)), params);
            Ok(())
        }
    };
}
110
/// Builds a `SpringApp` with every auto-discovery step applied, to keep `main.rs` short.
///
/// Usage: `spring_axum::auto_discover!().run().await`
///
/// Discovery order: components, interceptors, advices, controllers.
/// `SpringApp` is referenced through `$crate` so the macro does not depend on
/// the crate being linked under the literal name `spring_axum`.
#[macro_export]
macro_rules! auto_discover {
    () => {{
        $crate::SpringApp::new()
            .with_discovered_components()
            .with_discovered_interceptors()
            .with_discovered_advices()
            .with_discovered_controllers()
    }};
}
123
124// ---------------- Unified Error -----------------
125#[derive(Debug, Clone, serde::Serialize)]
126pub struct ErrorResponse { pub error: String }
127
128#[derive(Debug, Clone, serde::Serialize)]
129pub struct ValidationErrorsBody {
130    pub errors: HashMap<String, Vec<String>>,
131}
132
133#[derive(Debug, Clone)]
134pub enum AppError {
135    BadRequest(String),
136    NotFound(String),
137    Internal(String),
138    Validation(ValidationErrorsBody),
139    Unauthorized(String),
140    Forbidden(String),
141}
142
143pub type AppResult<T> = std::result::Result<T, AppError>;
144
145impl IntoResponse for AppError {
146    fn into_response(self) -> AxumResponse<AxumBody> {
147        match self {
148            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, axum::Json(ErrorResponse { error: msg })).into_response(),
149            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, axum::Json(ErrorResponse { error: msg })).into_response(),
150            AppError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(ErrorResponse { error: msg })).into_response(),
151            AppError::Validation(body) => (StatusCode::BAD_REQUEST, axum::Json(body)).into_response(),
152            AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, axum::Json(ErrorResponse { error: msg })).into_response(),
153            AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, axum::Json(ErrorResponse { error: msg })).into_response(),
154        }
155    }
156}
157
158impl From<JsonRejection> for AppError {
159    fn from(err: JsonRejection) -> Self { AppError::BadRequest(err.to_string()) }
160}
161
/// Flattens `validator` field errors into `AppError::Validation`.
///
/// Each field maps to the list of its error messages; an error without a
/// message falls back to its error code.
#[cfg(feature = "validator")]
impl From<validator::ValidationErrors> for AppError {
    fn from(errs: validator::ValidationErrors) -> Self {
        let errors: HashMap<String, Vec<String>> = errs
            .field_errors()
            .into_iter()
            .map(|(field, field_errs)| {
                let messages = field_errs
                    .iter()
                    .map(|e| match e.message.as_ref() {
                        Some(msg) => msg.to_string(),
                        // No message supplied: report the error code instead.
                        None => e.code.to_string(),
                    })
                    .collect::<Vec<String>>();
                (field.to_string(), messages)
            })
            .collect();
        AppError::Validation(ValidationErrorsBody { errors })
    }
}
183
184// ---------------- Validated Extractors (feature: validator) -----------------
/// JSON body extractor that also runs `validator::Validate` on the payload.
#[cfg(feature = "validator")]
pub struct ValidatedJson<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Extracts `Json<T>`, then validates it.
    ///
    /// A JSON rejection or a validation failure is converted through the
    /// corresponding `From` impl on `AppError`.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum::extract::Json::<T>::from_request(req, state);
        async move {
            extraction
                .await
                .map_err(AppError::from)
                .and_then(|axum::extract::Json(payload)| {
                    payload.validate().map_err(AppError::from)?;
                    Ok(ValidatedJson(payload))
                })
        }
    }
}
208
/// Query-string extractor that also runs `validator::Validate` on the parsed value.
#[cfg(feature = "validator")]
pub struct ValidatedQuery<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Extracts `Query<T>`, then validates it.
    ///
    /// A query rejection becomes `AppError::BadRequest` with the rejection
    /// text; a validation failure becomes `AppError::Validation`.
    fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum::extract::Query::<T>::from_request_parts(parts, state);
        async move {
            extraction
                .await
                .map_err(|rej| AppError::BadRequest(rej.to_string()))
                .and_then(|axum::extract::Query(params)| {
                    params.validate().map_err(AppError::from)?;
                    Ok(ValidatedQuery(params))
                })
        }
    }
}
232
/// Form body extractor that also runs `validator::Validate` on the parsed value.
#[cfg(feature = "validator")]
pub struct ValidatedForm<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequest<S> for ValidatedForm<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Extracts `Form<T>`, then validates it.
    ///
    /// A form rejection becomes `AppError::BadRequest` with the rejection
    /// text; a validation failure becomes `AppError::Validation`.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum::extract::Form::<T>::from_request(req, state);
        async move {
            extraction
                .await
                .map_err(|rej| AppError::BadRequest(rej.to_string()))
                .and_then(|axum::extract::Form(fields)| {
                    fields.validate().map_err(AppError::from)?;
                    Ok(ValidatedForm(fields))
                })
        }
    }
}
256
/// Optional JSON extractor built on axum-extra's `JsonDeserializer`, followed
/// by `validator::Validate`. Requires the `validator` and `axum_extra_json`
/// features.
#[cfg(all(feature = "validator", feature = "axum_extra_json"))]
pub struct ValidatedJsonStream<T>(pub T);

#[cfg(all(feature = "validator", feature = "axum_extra_json"))]
impl<T, S> axum::extract::FromRequest<S> for ValidatedJsonStream<T>
where
    T: DeserializeOwned + validator::Validate + 'static,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Extracts a `JsonDeserializer<T>`, deserializes it, then validates the value.
    ///
    /// Extraction and deserialization failures become `AppError::BadRequest`
    /// with the rejection text; a validation failure becomes
    /// `AppError::Validation`.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum_extra::extract::JsonDeserializer::<T>::from_request(req, state);
        async move {
            extraction
                .await
                .map_err(|rej| AppError::BadRequest(rej.to_string()))
                .and_then(|deserializer| {
                    let value = deserializer
                        .deserialize()
                        .map_err(|rej| AppError::BadRequest(rej.to_string()))?;
                    value.validate().map_err(AppError::from)?;
                    Ok(ValidatedJsonStream(value))
                })
        }
    }
}
284
285#[derive(Clone, serde::Deserialize, serde::Serialize, Debug)]
286pub struct ServerConfig {
287    pub host: String,
288    pub port: u16,
289    #[serde(default)]
290    pub enable_cors: bool,
291}
292
293impl Default for ServerConfig {
294    fn default() -> Self {
295        Self {
296            host: "127.0.0.1".into(),
297            port: 3000,
298            enable_cors: true,
299        }
300    }
301}
302
303#[derive(Clone, serde::Deserialize, serde::Serialize, Debug, Default)]
304pub struct AppConfig {
305    #[serde(default)]
306    pub server: ServerConfig,
307    #[serde(default)]
308    pub devtools: DevtoolsConfig,
309    #[serde(default)]
310    pub tx_advice: TxAdviceConfig,
311}
312
313#[derive(Clone, serde::Deserialize, serde::Serialize, Debug)]
314pub struct DevtoolsConfig {
315    #[serde(default)]
316    pub enabled: bool,
317    #[serde(default)]
318    pub watch: Vec<String>,
319}
320
321impl Default for DevtoolsConfig {
322    fn default() -> Self {
323        Self { enabled: false, watch: vec!["src".into(), "resources".into(), "templates".into()] }
324    }
325}
326
327#[derive(Clone, serde::Deserialize, serde::Serialize, Debug, Default)]
328pub struct TxAdviceConfig {
329    #[serde(default)]
330    pub enabled: bool,
331    #[serde(default)]
332    pub pointcut: Option<String>,
333}
334
335fn load_config() -> AppConfig {
336    let base_path = "application.yaml";
337    let mut config = match fs::read_to_string(base_path) {
338        Ok(s) => serde_yaml::from_str::<AppConfig>(&s).unwrap_or_default(),
339        Err(_) => AppConfig::default(),
340    };
341
342    if let Ok(profile) = std::env::var("SPRING_PROFILE") {
343        let prof_path = format!("application-{}.yaml", profile);
344        if let Ok(s) = fs::read_to_string(&prof_path) {
345            if let Ok(overlay) = serde_yaml::from_str::<AppConfig>(&s) {
346                // simple overlay: replace server entirely if provided
347                config.server = overlay.server;
348            }
349        }
350    }
351
352    // environment overrides
353    if let Ok(host) = std::env::var("SERVER_HOST") {
354        config.server.host = host;
355    }
356    if let Ok(port) = std::env::var("SERVER_PORT") {
357        if let Ok(p) = port.parse::<u16>() {
358            config.server.port = p;
359        }
360    }
361    if let Ok(cors) = std::env::var("SERVER_ENABLE_CORS") {
362        config.server.enable_cors = matches!(cors.as_str(), "1" | "true" | "True" | "TRUE");
363    }
364    // devtools environment overrides
365    if let Ok(enabled) = std::env::var("DEVTOOLS_ENABLED") {
366        config.devtools.enabled = matches!(enabled.as_str(), "1" | "true" | "True" | "TRUE");
367    }
368    if let Ok(watch) = std::env::var("DEVTOOLS_WATCH") {
369        let list: Vec<String> = watch
370            .split([',', ';', '|'])
371            .map(|s| s.trim().to_string())
372            .filter(|s| !s.is_empty())
373            .collect();
374        if !list.is_empty() { config.devtools.watch = list; }
375    }
376    config
377}
378
379fn validate_config(cfg: &AppConfig) {
380    if cfg.server.host.trim().is_empty() {
381        tracing::warn!(target = "spring_axum", "server.host is empty; defaulting may cause bind errors");
382    }
383    if cfg.server.port == 0 {
384        tracing::warn!(target = "spring_axum", port = cfg.server.port, "server.port must be in 1..=65535");
385    }
386    if cfg.server.host.starts_with("http://") || cfg.server.host.starts_with("https://") {
387        tracing::warn!(target = "spring_axum", host = %cfg.server.host, "server.host should be a hostname/ip, not URL");
388    }
389    #[cfg(feature = "devtools")]
390    if cfg.devtools.enabled {
391        for p in &cfg.devtools.watch {
392            if !std::path::Path::new(p).exists() {
393                tracing::info!(target = "spring_axum", path = %p, "devtools watch path not found (will still attempt watching)");
394            }
395        }
396    }
397}
398
399pub trait Controller {
400    fn routes(&self) -> Router;
401}
402
403pub struct SpringApp {
404    router: Router,
405    config: AppConfig,
406    ctx: Arc<ApplicationContext>,
407}
408
409impl SpringApp {
410    pub fn new() -> Self {
411        let config = load_config();
412        validate_config(&config);
413        let router = Router::new().route("/actuator/health", get(|| async { "OK" }));
414        let ctx = Arc::new(ApplicationContext::new());
415        init_defaults_once();
416        // mount simple cache metrics
417        let router = router.merge(cache_metrics_router());
418        #[cfg(feature = "swagger")]
419        let router = router.merge(swagger_routes());
420        #[cfg(feature = "devtools")]
421        let router = if config.devtools.enabled { router.merge(devtools_router()) } else { router };
422        #[cfg(feature = "devtools")]
423        if config.devtools.enabled { start_devtools_watcher(&config); }
424        Self { router, config, ctx }
425    }
426
427    pub fn with_controller(mut self, controller: impl Controller) -> Self {
428        self.router = self.router.merge(controller.routes());
429        self
430    }
431
432    pub fn with_router(mut self, router: Router) -> Self {
433        self.router = self.router.merge(router);
434        self
435    }
436
437    pub async fn run(self) -> Result<()> {
438        init_tracing();
439
440        // Auto initialize optional integrations so users don't need to hand-write setup macros
441        auto_init_integrations();
442
443        print_banner_if_any();
444
445        tracing::info!(target = "spring_axum", host = %self.config.server.host, port = self.config.server.port, cors = self.config.server.enable_cors, devtools = self.config.devtools.enabled, "Startup configuration");
446
447        // 发布应用启动事件,供监听器使用
448        publish_event(AppStarting { config: self.config.clone() }, &self.ctx);
449
450        let mut router = self
451            .router
452            .layer(TraceLayer::new_for_http())
453            .layer(AddExtensionLayer::new(self.ctx.clone()));
454        if self.config.server.enable_cors {
455            router = router.layer(CorsLayer::permissive());
456        }
457
458        let addr: SocketAddr = format!("{}:{}", self.config.server.host, self.config.server.port)
459            .parse()
460            .expect("invalid server address");
461        tracing::info!(target: "spring_axum", "Listening on http://{}", addr);
462        axum::serve(tokio::net::TcpListener::bind(addr).await?, router).await?;
463        Ok(())
464    }
465}
466
467fn init_tracing() {
468    let _ = fmt()
469        .with_env_filter(
470            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
471        )
472        .with_target(false)
473        .compact()
474        .try_init();
475}
476
477// ---------------- DI Container -----------------
/// Minimal dependency-injection container: at most one shared instance per type.
///
/// Values are stored as `Arc<T>` behind a type-erased box, keyed by the
/// `TypeId` of `T`.
pub struct ApplicationContext {
    inner: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl Default for ApplicationContext {
    /// Same as [`ApplicationContext::new`]: an empty container.
    fn default() -> Self {
        Self::new()
    }
}

impl ApplicationContext {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self { inner: RwLock::new(HashMap::new()) }
    }

    /// Registers `value` as the instance for type `T`, replacing any
    /// previously registered instance of the same type.
    pub fn register<T: Send + Sync + 'static>(&self, value: T) {
        let arc: Arc<T> = Arc::new(value);
        self.inner
            .write()
            .expect("ApplicationContext lock poisoned")
            .insert(TypeId::of::<T>(), Box::new(arc));
    }

    /// Registers a pre-boxed entry under an explicit type id.
    ///
    /// For [`resolve`](Self::resolve) to find it, `arc_any` must box an
    /// `Arc<T>` where `type_id == TypeId::of::<T>()`; any other payload is
    /// stored but never resolved.
    pub fn register_dynamic(&self, type_id: StdTypeId, arc_any: Box<dyn Any + Send + Sync>) {
        self.inner
            .write()
            .expect("ApplicationContext lock poisoned")
            .insert(type_id, arc_any);
    }

    /// Returns the shared instance registered for `T`, or `None` when nothing
    /// is registered for `T` or the stored payload is not an `Arc<T>`.
    pub fn resolve<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.inner
            .read()
            .expect("ApplicationContext lock poisoned")
            .get(&TypeId::of::<T>())
            .and_then(|b| b.downcast_ref::<Arc<T>>().cloned())
    }
}
507
508// ---------------- Transactions -----------------
509pub trait TransactionManager: Send + Sync + 'static {
510    fn begin(&self) -> BoxFuture<'static, Result<(), AppError>>;
511    fn commit(&self) -> BoxFuture<'static, Result<(), AppError>>;
512    fn rollback(&self) -> BoxFuture<'static, Result<(), AppError>>;
513}
514
515#[derive(Clone, Default)]
516pub struct NoopTransactionManager;
517
518impl TransactionManager for NoopTransactionManager {
519    fn begin(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
520    fn commit(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
521    fn rollback(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
522}
523
524static GLOBAL_TX_MANAGER: OnceCell<Arc<dyn TransactionManager>> = OnceCell::new();
525
526fn init_defaults_once() {
527    // Initialize default managers only once
528    let _ = GLOBAL_TX_MANAGER.set(Arc::new(NoopTransactionManager::default()));
529    let _ = GLOBAL_CACHE.set(InMemoryCache::default());
530}
531
532// Try to auto-initialize integrations so users don't have to wire macros in main.rs
533fn auto_init_integrations() {
534    // MyBatis: initialize if standard directory exists and not yet initialized
535    let mybatis_dir = std::path::Path::new("resources/mybatis");
536    if mybatis_dir.exists() {
537        if spring_axum_mybatis::global_registry().is_none() {
538            match spring_axum_mybatis::init_global_from_dir_once("resources/mybatis") {
539                Ok(_) => tracing::info!(target = "spring_axum", path = %mybatis_dir.display(), "MyBatis registry initialized"),
540                Err(e) => tracing::warn!(target = "spring_axum", error = %e, "Failed to init MyBatis registry"),
541            }
542        }
543    } else {
544        tracing::debug!(target = "spring_axum", "resources/mybatis not found; skipping MyBatis init");
545    }
546
547    // sqlx pool: initialize lazily from DATABASE_URL if feature enabled and pool not set yet
548    #[cfg(feature = "sqlx_postgres")]
549    {
550        if GLOBAL_SQLX_POOL.get().is_none() {
551            match std::env::var("DATABASE_URL") {
552                Ok(url) => {
553                    match sqlx::postgres::PgPoolOptions::new().connect_lazy(&url) {
554                        Ok(pool) => {
555                            let _ = GLOBAL_SQLX_POOL.set(pool);
556                            tracing::info!(target = "spring_axum", "sqlx pool initialized from env DATABASE_URL");
557                        }
558                        Err(e) => tracing::warn!(target = "spring_axum", error = %e, "Failed to lazy connect sqlx pool"),
559                    }
560                }
561                Err(_) => tracing::debug!(target = "spring_axum", "DATABASE_URL not set; skipping sqlx pool init"),
562            }
563        }
564    }
565}
566
567pub fn set_transaction_manager<T: TransactionManager>(mgr: Arc<T>) {
568    let _ = GLOBAL_TX_MANAGER.set(mgr);
569}
570
571pub async fn transaction<F, Fut, T>(f: F) -> AppResult<T>
572where
573    F: FnOnce() -> Fut,
574    Fut: std::future::Future<Output = AppResult<T>>,
575{
576    let mgr = GLOBAL_TX_MANAGER.get().cloned().unwrap_or_else(|| Arc::new(NoopTransactionManager));
577    if let Err(e) = mgr.begin().await { return Err(e); }
578    match f().await {
579        Ok(val) => {
580            if let Err(e) = mgr.commit().await { return Err(e); }
581            Ok(val)
582        }
583        Err(err) => {
584            let _ = mgr.rollback().await;
585            Err(err)
586        }
587    }
588}
589
590// ---------------- In-Memory Cache -----------------
591#[derive(Clone, Default)]
592pub struct InMemoryCache {
593    inner: Arc<RwLock<HashMap<String, (Box<dyn Any + Send + Sync>, Instant, Option<Duration>, StdTypeId)>>>,
594    stats: Arc<RwLock<CacheStats>>,
595}
596
597#[derive(Default, Clone, serde::Serialize)]
598pub struct CacheStats {
599    hits: u64,
600    misses: u64,
601    puts: u64,
602    evicts: u64,
603}
604
605impl InMemoryCache {
606    pub fn get_typed<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<T> {
607        let mut guard = self.inner.write().unwrap();
608        if let Some((boxed_any, inserted, ttl_opt, type_id)) = guard.get(key) {
609            // TTL expiry check
610            let expired = match ttl_opt {
611                Some(ttl) => *inserted + *ttl < Instant::now(),
612                None => false,
613            };
614            if expired {
615                guard.remove(key);
616                self.stats.write().unwrap().misses += 1;
617                return None;
618            }
619            if *type_id == StdTypeId::of::<T>() {
620                if let Some(arc_val) = (&**boxed_any as &dyn Any).downcast_ref::<Arc<T>>() {
621                    self.stats.write().unwrap().hits += 1;
622                    return Some((**arc_val).clone());
623                }
624            }
625            self.stats.write().unwrap().misses += 1;
626            None
627        } else {
628            self.stats.write().unwrap().misses += 1;
629            None
630        }
631    }
632
633    pub fn put_typed<T: Clone + Send + Sync + 'static>(&self, key: String, value: T, ttl: Option<Duration>) {
634        let boxed: Box<dyn Any + Send + Sync> = Box::new(Arc::new(value));
635        self.inner
636            .write()
637            .unwrap()
638            .insert(key, (boxed, Instant::now(), ttl, StdTypeId::of::<T>()));
639        self.stats.write().unwrap().puts += 1;
640    }
641
642    pub fn evict(&self, key: &str) {
643        self.inner.write().unwrap().remove(key);
644        self.stats.write().unwrap().evicts += 1;
645    }
646
647    pub fn stats(&self) -> CacheStats { self.stats.read().unwrap().clone() }
648}
649
650static GLOBAL_CACHE: OnceCell<InMemoryCache> = OnceCell::new();
651
652pub fn cache_instance() -> InMemoryCache {
653    GLOBAL_CACHE.get().cloned().unwrap_or_default()
654}
655
656pub fn default_cache_key(fn_name: &str, args_json: &serde_json::Value) -> String {
657    format!("{}:{}", fn_name, args_json)
658}
659
660// Expose simple metrics endpoint for cache stats
661fn cache_metrics_router() -> Router {
662    let handler = || async move {
663        let stats = cache_instance().stats();
664        axum::Json(stats)
665    };
666    Router::new().route("/actuator/metrics/cache", get(handler))
667}
668
669fn print_banner_if_any() {
670    if let Ok(banner) = fs::read_to_string("banner.txt") {
671        println!("{}", banner);
672    } else {
673        // default banner with version and feature hints
674        println!(":: Spring-Axum v{} ::", env!("CARGO_PKG_VERSION"));
675        let mut features: Vec<&str> = Vec::new();
676        if cfg!(feature = "validator") { features.push("validator"); }
677        if cfg!(feature = "swagger") { features.push("swagger"); }
678        if cfg!(feature = "sqlx_postgres") { features.push("sqlx_postgres"); }
679        if cfg!(feature = "devtools") { features.push("devtools"); }
680        if !features.is_empty() {
681            println!("Enabled features: {}", features.join(", "));
682        }
683    }
684}
685
686// ---------------- Devtools Hot Reload (feature: devtools) -----------------
/// Broadcast sender used to fan devtools messages out to SSE subscribers.
#[cfg(feature = "devtools")]
static DEVTOOLS_BUS: OnceCell<broadcast::Sender<String>> = OnceCell::new();

/// Router with the devtools endpoints: `/devtools/events` (SSE stream) and
/// `/devtools/trigger` (manual trigger).
#[cfg(feature = "devtools")]
fn devtools_router() -> Router {
    let events = get(devtools_sse);
    let trigger = get(devtools_trigger);
    Router::new()
        .route("/devtools/events", events)
        .route("/devtools/trigger", trigger)
}
696
/// SSE endpoint: forwards every message published on the devtools bus as a
/// `reload` event. Receive errors on the broadcast stream are skipped. A
/// keep-alive is sent every 10 seconds.
#[cfg(feature = "devtools")]
async fn devtools_sse() -> Sse<impl futures::Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
    let bus = DEVTOOLS_BUS.get_or_init(|| broadcast::channel(128).0);
    let events = BroadcastStream::new(bus.subscribe()).filter_map(|received| {
        received
            .ok()
            .map(|msg| Ok(SseEvent::default().event("reload").data(msg)))
    });
    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(10))
        .text("keep-alive");
    Sse::new(events).keep_alive(keep_alive)
}
707
/// Publishes the message `manual` on the devtools bus and answers `ok`.
/// The send result is ignored.
#[cfg(feature = "devtools")]
async fn devtools_trigger() -> &'static str {
    let bus = DEVTOOLS_BUS.get_or_init(|| broadcast::channel(128).0);
    let _ = bus.send("manual".into());
    "ok"
}
713
/// Spawns a background thread that watches the configured paths and
/// publishes messages on the devtools bus.
///
/// Does nothing when devtools is disabled. For every file-system event two
/// messages are sent (`reload` and `changed: <paths>`); watcher errors are
/// sent as `watch_error: <error>`. Failures to register a path and failures
/// to send are ignored.
///
/// # Panics
/// The spawned thread panics if the file watcher cannot be created.
#[cfg(feature = "devtools")]
fn start_devtools_watcher(cfg: &AppConfig) {
    if !cfg.devtools.enabled {
        return;
    }
    let bus = DEVTOOLS_BUS.get_or_init(|| broadcast::channel(128).0).clone();
    let watch_paths = cfg.devtools.watch.clone();
    std::thread::spawn(move || {
        let (evt_tx, evt_rx) = std::sync::mpsc::channel();
        let forward = move |res| {
            let _ = evt_tx.send(res);
        };
        let mut watcher = RecommendedWatcher::new(forward, NotifyConfig::default())
            .expect("Failed to init file watcher");
        for path in watch_paths.iter() {
            let _ = watcher.watch(std::path::Path::new(path), RecursiveMode::Recursive);
        }
        // `watcher` stays alive for the whole loop; it owns the sending side.
        for res in evt_rx {
            match res {
                Err(e) => {
                    let _ = bus.send(format!("watch_error: {}", e));
                }
                Ok(event) => {
                    let _ = bus.send("reload".into());
                    let _ = bus.send(format!("changed: {:?}", event.paths));
                }
            }
        }
    });
}
741
742pub struct Component<T>(pub Arc<T>);
743
744impl<S, T> axum::extract::FromRequestParts<S> for Component<T>
745where
746    T: Send + Sync + 'static,
747{
748    type Rejection = (axum::http::StatusCode, String);
749
750    fn from_request_parts(
751        parts: &mut axum::http::request::Parts,
752        _state: &S,
753    ) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
754        let ctx_opt = parts.extensions.get::<Arc<ApplicationContext>>().cloned();
755        async move {
756            if let Some(ctx) = ctx_opt {
757                if let Some(v) = ctx.resolve::<T>() {
758                    return Ok(Component(v));
759                }
760            }
761            Err((axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Component not found".into()))
762        }
763    }
764}
765
766impl SpringApp {
767    pub fn with_component<T: Send + Sync + 'static>(self, value: T) -> Self {
768        self.ctx.register::<T>(value);
769        self
770    }
771
772    pub fn with_discovered_components(self) -> Self {
773        for reg in inventory::iter::<ComponentRegistration> {
774            let (type_id, arc_any) = (reg.init)(&self.ctx);
775            self.ctx.register_dynamic(type_id, arc_any);
776        }
777        self
778    }
779}
780
781// -------- Component auto-registration (inventory) --------
782pub struct ComponentRegistration {
783    pub init: fn(&ApplicationContext) -> (StdTypeId, Box<dyn Any + Send + Sync>),
784}
785
786inventory::collect!(ComponentRegistration);
787
788// -------- Controller auto-registration (inventory) --------
789pub struct ControllerRouterRegistration {
790    pub init: fn() -> Router,
791}
792
793inventory::collect!(ControllerRouterRegistration);
794
795impl SpringApp {
796    pub fn with_discovered_controllers(mut self) -> Self {
797        for reg in inventory::iter::<ControllerRouterRegistration> {
798            let router = (reg.init)();
799            self.router = self.router.merge(router);
800        }
801        self
802    }
803}
804
805// ---------------- Interceptor Layer -----------------
806pub trait Interceptor: Clone + Send + Sync + 'static {
807    fn on_request(&self, req: Request<Body>) -> Request<Body> { req }
808    fn on_response(&self, res: Response<Body>) -> Response<Body> { res }
809}
810
811// -------- Interceptor auto-registration (inventory) --------
812pub struct InterceptorRegistration {
813    pub apply: fn(Router) -> Router,
814}
815
816inventory::collect!(InterceptorRegistration);
817
818impl SpringApp {
819    pub fn with_discovered_interceptors(mut self) -> Self {
820        for reg in inventory::iter::<InterceptorRegistration> {
821            self.router = (reg.apply)(self.router);
822        }
823        self
824    }
825}
826
827#[derive(Clone)]
828pub struct InterceptorLayer<I> {
829    interceptor: I,
830}
831
832impl<I> InterceptorLayer<I> {
833    pub fn new(interceptor: I) -> Self { Self { interceptor } }
834}
835
836impl<S, I> Layer<S> for InterceptorLayer<I>
837where
838    I: Interceptor,
839{
840    type Service = InterceptorService<S, I>;
841    fn layer(&self, inner: S) -> Self::Service { InterceptorService { inner, interceptor: self.interceptor.clone() } }
842}
843
844#[derive(Clone)]
845pub struct InterceptorService<S, I> {
846    inner: S,
847    interceptor: I,
848}
849
850impl<S, I> Service<Request<Body>> for InterceptorService<S, I>
851where
852    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
853    S::Future: Send + 'static,
854    I: Interceptor,
855{
856    type Response = Response<Body>;
857    type Error = S::Error;
858    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
859
860    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
861        self.inner.poll_ready(cx)
862    }
863
864    fn call(&mut self, req: Request<Body>) -> Self::Future {
865        let interceptor = self.interceptor.clone();
866        let mut inner = self.inner.clone();
867        let req = interceptor.on_request(req);
868        Box::pin(async move {
869            let res = inner.call(req).await?;
870            Ok(interceptor.on_response(res))
871        })
872    }
873}
874
875impl SpringApp {
876    pub fn with_interceptor<I: Interceptor>(mut self, interceptor: I) -> Self {
877        self.router = self.router.layer(InterceptorLayer::new(interceptor));
878        self
879    }
880}
881
882// ---------------- Advice (Pointcut-based) -----------------
883#[derive(Clone)]
884pub struct Pointcut {
885    methods: Option<Vec<axum::http::Method>>,
886    path_re: Option<Regex>,
887}
888
889impl Pointcut {
890    pub fn matches(&self, req: &Request<Body>) -> bool {
891        let method_ok = match &self.methods {
892            Some(ms) => ms.iter().any(|m| m == req.method()),
893            None => true,
894        };
895        let path_ok = match &self.path_re {
896            Some(re) => re.is_match(req.uri().path()),
897            None => true,
898        };
899        method_ok && path_ok
900    }
901}
902
/// Translates a path glob into an anchored regular-expression string.
///
/// * `**` becomes `.*` (matches across `/`).
/// * `*` becomes `[^/]*` (matches within one path segment).
/// * Regex metacharacters, including `\`, are escaped so they match literally.
/// * Every other character is copied through unchanged.
///
/// The result is wrapped in `^` … `$`. The glob is walked by `char`, not by
/// byte, so non-ASCII characters are preserved intact.
fn glob_to_regex(glob: &str) -> String {
    let mut pattern = String::with_capacity(glob.len() + 2);
    pattern.push('^');
    let mut chars = glob.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '*' => {
                // Consume a directly following `*` so that `**` is one token.
                if chars.next_if_eq(&'*').is_some() {
                    pattern.push_str(".*");
                } else {
                    pattern.push_str("[^/]*");
                }
            }
            '.' | '+' | '?' | '(' | ')' | '|' | '{' | '}' | '[' | ']' | '^' | '$' | '\\' => {
                pattern.push('\\');
                pattern.push(c);
            }
            other => pattern.push(other),
        }
    }
    pattern.push('$');
    pattern
}
932
933fn parse_pointcut(expr: &str) -> Pointcut {
934    let parts: Vec<&str> = expr.split_whitespace().collect();
935    if parts.is_empty() {
936        return Pointcut { methods: None, path_re: None };
937    }
938    if parts.len() == 1 {
939        let re = Regex::new(&glob_to_regex(parts[0])).ok();
940        return Pointcut { methods: None, path_re: re };
941    }
942    let methods: Option<Vec<axum::http::Method>> = {
943        let ms = parts[0]
944            .split(',')
945            .filter_map(|m| axum::http::Method::from_bytes(m.trim().as_bytes()).ok())
946            .collect::<Vec<_>>();
947        if ms.is_empty() { None } else { Some(ms) }
948    };
949    let path_re = Regex::new(&glob_to_regex(parts[1])).ok();
950    Pointcut { methods, path_re }
951}
952
953pub trait Advice: Send + Sync + 'static {
954    fn before(&self, req: Request<Body>) -> Request<Body> { req }
955    fn after(&self, res: Response<Body>) -> Response<Body> { res }
956    fn before_async(&self, req: Request<Body>) -> futures::future::BoxFuture<'static, Request<Body>> {
957        Box::pin(async move { req })
958    }
959    fn after_async(&self, res: Response<Body>) -> futures::future::BoxFuture<'static, Response<Body>> {
960        Box::pin(async move { res })
961    }
962    fn pointcut_expr(&self) -> &'static str;
963    fn pointcut(&self) -> Pointcut { parse_pointcut(self.pointcut_expr()) }
964}
965
966#[derive(Clone)]
967pub struct AdviceLayer {
968    advice: Arc<dyn Advice>,
969    pointcut: Pointcut,
970}
971
972impl AdviceLayer {
973    pub fn new(advice: Box<dyn Advice>) -> Self { Self { pointcut: advice.pointcut(), advice: Arc::from(advice) } }
974}
975
/// Service produced by [`AdviceLayer`]: runs the advice hooks around `inner`
/// for requests that match `pointcut`.
#[derive(Clone)]
pub struct AdviceService<S> {
    /// The wrapped service.
    inner: S,
    /// Shared advice implementation.
    advice: Arc<dyn Advice>,
    /// Decides which requests the advice applies to.
    pointcut: Pointcut,
}
982
983impl<S> Layer<S> for AdviceLayer {
984    type Service = AdviceService<S>;
985    fn layer(&self, inner: S) -> Self::Service { AdviceService { inner, advice: self.advice.clone(), pointcut: self.pointcut.clone() } }
986}
987
988impl<S> Service<Request<Body>> for AdviceService<S>
989where
990    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
991    S::Future: Send + 'static,
992{
993    type Response = Response<Body>;
994    type Error = S::Error;
995    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
996
997    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) }
998
999    fn call(&mut self, req: Request<Body>) -> Self::Future {
1000        let mut inner = self.inner.clone();
1001        let advice = self.advice.clone();
1002        let pointcut = self.pointcut.clone();
1003        let matched = pointcut.matches(&req);
1004        let req = if matched { (*advice).before(req) } else { req };
1005        Box::pin(async move {
1006            let req2 = if matched { (*advice).before_async(req).await } else { req };
1007            let res = inner.call(req2).await?;
1008            let res2 = if matched { (*advice).after(res) } else { res };
1009            let res3 = if matched { (*advice).after_async(res2).await } else { res2 };
1010            Ok(res3)
1011        })
1012    }
1013}
1014
1015// -------- Advice auto-registration (inventory) --------
/// Registration record for an advice, collected through `inventory`.
pub struct AnyAdviceRegistration {
    /// Constructs the boxed advice instance.
    pub build_boxed: fn() -> Box<dyn Advice>,
}

// Declares `AnyAdviceRegistration` as an `inventory`-collected type.
inventory::collect!(AnyAdviceRegistration);
1021
1022impl SpringApp {
1023    pub fn with_discovered_advices(mut self) -> Self {
1024        for reg in inventory::iter::<AnyAdviceRegistration> {
1025            let boxed = (reg.build_boxed)();
1026            self.router = self.router.layer(AdviceLayer::new(boxed));
1027        }
1028        // Config-driven TransactionAdvice (optional)
1029        if self.config.tx_advice.enabled {
1030            let expr = self
1031                .config
1032                .tx_advice
1033                .pointcut
1034                .as_deref()
1035                .unwrap_or("method:* path:/**");
1036            // Leak string to 'static for Advice API; done once at startup.
1037            let expr_static: &'static str = Box::leak(expr.to_string().into_boxed_str());
1038            let boxed: Box<dyn Advice> = Box::new(TransactionAdvice::new(expr_static));
1039            self.router = self.router.layer(AdviceLayer::new(boxed));
1040        }
1041        self
1042    }
1043}
1044
1045// ---------------- Application Events -----------------
/// Registration record for an application-event listener, collected through `inventory`.
pub struct EventListenerRegistration {
    /// Returns `true` when the listener accepts events of the given type id.
    pub matches: fn(StdTypeId) -> bool,
    /// Invoked with the type-erased event and the application context.
    pub handle: fn(&dyn Any, &ApplicationContext),
}

// Declares `EventListenerRegistration` as an `inventory`-collected type.
inventory::collect!(EventListenerRegistration);
1052
1053pub fn publish_event<T: Send + Sync + 'static>(evt: T, ctx: &ApplicationContext) {
1054    let type_id = StdTypeId::of::<T>();
1055    for reg in inventory::iter::<EventListenerRegistration> {
1056        if (reg.matches)(type_id) {
1057            (reg.handle)(&evt, ctx);
1058        }
1059    }
1060}
1061
/// Event payload carrying the application configuration.
// NOTE(review): the name suggests this is published during startup — confirm at the publish site.
#[derive(Clone, Debug)]
pub struct AppStarting {
    /// Configuration carried by the event.
    pub config: AppConfig,
}
1066
/// Registers `$handler` as a listener for events of type `$ty`.
///
/// Expands to an `inventory::submit!` of an `EventListenerRegistration` whose
/// `matches` compares the incoming `TypeId` with that of `$ty`, and whose
/// `handle` downcasts the event to `$ty` and calls `$handler(event, ctx)`.
///
/// The expansion names `inventory` and `::spring_axum` directly, so both paths
/// must resolve in the crate that invokes the macro.
#[macro_export]
macro_rules! event_listener {
    ($ty:ty, $handler:path) => {
        inventory::submit! {
            ::spring_axum::EventListenerRegistration {
                matches: |incoming: ::std::any::TypeId| incoming == ::std::any::TypeId::of::<$ty>(),
                handle: |evt: &dyn ::std::any::Any, ctx: &::spring_axum::ApplicationContext| {
                    // A failed downcast is ignored: the handler is simply not called.
                    if let Some(v) = evt.downcast_ref::<$ty>() {
                        $handler(v, ctx);
                    }
                },
            }
        }
    };
}
1082
/// Registers the advice type `$ty` for auto-discovery.
///
/// Expands to an `inventory::submit!` of an `AnyAdviceRegistration` that builds
/// the advice with `<$ty as Default>::default()`, so `$ty` must implement `Default`.
/// The expansion names `inventory` and `::spring_axum` directly, so both paths
/// must resolve in the crate that invokes the macro.
#[macro_export]
macro_rules! advice {
    ($ty:ty) => {
        inventory::submit! {
            ::spring_axum::AnyAdviceRegistration {
                build_boxed: || Box::new(<$ty as Default>::default()),
            }
        }
    };
}
1093
/// Advice that drives the global transaction manager around matching requests.
#[derive(Clone)]
pub struct TransactionAdvice {
    /// Pointcut expression returned by `pointcut_expr`.
    expr: &'static str,
}

impl TransactionAdvice {
    /// Creates the advice for the given pointcut expression.
    pub const fn new(expr: &'static str) -> Self { Self { expr } }
}
1102
1103impl Advice for TransactionAdvice {
1104    fn pointcut_expr(&self) -> &'static str { self.expr }
1105    fn before_async(&self, req: Request<Body>) -> futures::future::BoxFuture<'static, Request<Body>> {
1106        Box::pin(async move {
1107            if let Some(mgr) = GLOBAL_TX_MANAGER.get() {
1108                let _ = mgr.begin().await;
1109            }
1110            req
1111        })
1112    }
1113    fn after_async(&self, res: Response<Body>) -> futures::future::BoxFuture<'static, Response<Body>> {
1114        Box::pin(async move {
1115            if let Some(mgr) = GLOBAL_TX_MANAGER.get() {
1116                if res.status().is_success() {
1117                    let _ = mgr.commit().await;
1118                } else {
1119                    let _ = mgr.rollback().await;
1120                }
1121            }
1122            res
1123        })
1124    }
1125}
1126
/// Registers a `TransactionAdvice` for the literal pointcut expression `$expr`.
///
/// Expands to an `inventory::submit!` of an `AnyAdviceRegistration` that builds
/// `TransactionAdvice::new($expr)`. The expansion names `inventory` and
/// `::spring_axum` directly, so both paths must resolve in the invoking crate.
#[macro_export]
macro_rules! transaction_advice {
    ($expr:literal) => {
        inventory::submit! {
            ::spring_axum::AnyAdviceRegistration {
                build_boxed: || Box::new(::spring_axum::TransactionAdvice::new($expr)),
            }
        }
    };
}
1137
1138// ---------------- Swagger Placeholder (feature = "swagger") -----------------
/// Builds the placeholder OpenAPI document: title "Spring Axum", version
/// "0.1.0", and an empty path set.
#[cfg(feature = "swagger")]
fn build_openapi() -> utoipa::openapi::OpenApi {
    let info = InfoBuilder::new().title("Spring Axum").version("0.1.0").build();
    let paths = PathsBuilder::new().build();
    OpenApiBuilder::new().info(info).paths(paths).build()
}
1146
/// Builds the router that serves the placeholder Swagger page at `/swagger-ui`
/// and the OpenAPI document at `/api-docs/openapi.json`.
///
/// The document is built once into `OPENAPI`, serialized, and cached in
/// `OPENAPI_JSON`. A minimal fallback document is used if serialization fails
/// or the cache is empty.
#[cfg(feature = "swagger")]
fn swagger_routes() -> Router {
    /// Static placeholder page. Inside a raw string `"` needs no escaping; a
    /// backslash would be emitted literally and break the attribute values.
    fn ui_html() -> String {
        r#"<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Swagger UI</title>
  </head>
  <body>
    <h1>Swagger UI (placeholder)</h1>
    <p>OpenAPI JSON at <a href="/api-docs/openapi.json">/api-docs/openapi.json</a></p>
  </body>
</html>"#.to_string()
    }

    /// Minimal document served when the real one is unavailable.
    fn fallback_spec() -> serde_json::Value {
        serde_json::json!({
            "openapi": "3.0.0",
            "info": {"title": "Spring Axum", "version": "0.1.0"},
            "paths": {}
        })
    }

    let openapi = OPENAPI.get_or_init(build_openapi);
    let openapi_json = serde_json::to_value(openapi).unwrap_or_else(|_| fallback_spec());
    // A failed `set` means a value is already cached; keep it.
    let _ = OPENAPI_JSON.set(openapi_json);
    Router::new()
        .route("/swagger-ui", get(|| async { Html(ui_html()) }))
        .route("/api-docs/openapi.json", get(|| async {
            axum::Json(OPENAPI_JSON.get().cloned().unwrap_or_else(fallback_spec))
        }))
}
1179#[cfg(feature = "sqlx_postgres")]
1180use sqlx::{Pool, Postgres, PgConnection};
/// Process-wide Postgres pool used by `sqlx_pool` and `sqlx_transaction`.
#[cfg(feature = "sqlx_postgres")]
static GLOBAL_SQLX_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();

/// Stores `pool` in the global cell.
///
/// The result of `set` is discarded, so a failed `set` (for example because a
/// pool was stored earlier) is silent.
#[cfg(feature = "sqlx_postgres")]
pub fn set_sqlx_pool(pool: Pool<Postgres>) {
    let _ = GLOBAL_SQLX_POOL.set(pool);
}
1188
/// Returns a clone of the global pool handle, or `None` if no pool has been stored.
#[cfg(feature = "sqlx_postgres")]
pub fn sqlx_pool() -> Option<Pool<Postgres>> { GLOBAL_SQLX_POOL.get().cloned() }
1191
/// Runs `f` inside a database transaction on a connection from the global pool.
///
/// Commits when `f` resolves to `Ok`, rolls back when it resolves to `Err`, and
/// returns `f`'s result.
///
/// The transaction is held in sqlx's `Transaction` guard rather than opened with
/// raw `BEGIN`/`COMMIT` statements. If this future is dropped or `f` panics
/// before completion, the guard rolls the transaction back instead of handing a
/// connection with an open transaction back to the pool.
///
/// # Errors
///
/// Returns `AppError::Internal` when the pool is not set, or when acquiring a
/// connection, beginning, or committing fails. An error from `f` is returned
/// as-is after a best-effort rollback.
#[cfg(feature = "sqlx_postgres")]
pub async fn sqlx_transaction<F, Fut, T>(f: F) -> AppResult<T>
where
    F: for<'a> FnOnce(&'a mut PgConnection) -> Fut,
    Fut: std::future::Future<Output = AppResult<T>> + Send,
    T: Send + 'static,
{
    let pool = GLOBAL_SQLX_POOL
        .get()
        .cloned()
        .ok_or_else(|| AppError::Internal("sqlx pool not set".into()))?;
    let mut conn = pool
        .acquire()
        .await
        .map_err(|e| AppError::Internal(format!("acquire conn error: {}", e)))?;
    let mut tx = sqlx::Connection::begin(&mut *conn)
        .await
        .map_err(|e| AppError::Internal(format!("begin tx error: {}", e)))?;
    match f(&mut *tx).await {
        Ok(val) => {
            tx.commit()
                .await
                .map_err(|e| AppError::Internal(format!("commit error: {}", e)))?;
            Ok(val)
        }
        Err(err) => {
            // Best effort: the caller's error takes precedence over a rollback failure.
            let _ = tx.rollback().await;
            Err(err)
        }
    }
}