1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 fs,
5 net::SocketAddr,
6 sync::{Arc, RwLock},
7};
8
9use anyhow::Result;
10use tower_http::{cors::CorsLayer, trace::TraceLayer};
12use tower_http::add_extension::AddExtensionLayer;
13use tower::{Layer, Service};
14use axum::body::Body;
15use axum::http::{Request, Response};
16use futures::future::BoxFuture;
17use axum::response::IntoResponse;
18use serde::de::DeserializeOwned;
19use tracing_subscriber::{fmt, EnvFilter};
20use inventory;
21use std::any::TypeId as StdTypeId;
22use once_cell::sync::OnceCell;
23use std::time::{Duration, Instant};
24use axum::response::Html;
25use regex::Regex;
26#[cfg(feature = "devtools")]
27use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
28#[cfg(feature = "devtools")]
29use tokio::sync::broadcast;
30#[cfg(feature = "devtools")]
31use tokio_stream::wrappers::BroadcastStream;
32#[cfg(feature = "devtools")]
33use tokio_stream::StreamExt;
34#[cfg(feature = "devtools")]
35use notify::{Config as NotifyConfig, RecommendedWatcher, RecursiveMode};
36#[cfg(feature = "devtools")]
37use notify::Watcher;
38#[cfg(feature = "swagger")]
39use utoipa::openapi::{OpenApiBuilder, InfoBuilder, PathsBuilder};
/// Process-wide, write-once slot holding an OpenAPI document.
#[cfg(feature = "swagger")]
static OPENAPI: OnceCell<utoipa::openapi::OpenApi> = OnceCell::new();

/// Process-wide, write-once slot holding a JSON value.
// NOTE(review): presumably the serialized form of `OPENAPI` — confirm against the code that fills it.
#[cfg(feature = "swagger")]
static OPENAPI_JSON: OnceCell<serde_json::Value> = OnceCell::new();
44
45pub use axum::Router;
47pub use axum::routing::{delete, get, post, put};
48pub use axum::extract::Path;
49pub use axum::extract::Json;
50pub use axum::extract::rejection::JsonRejection;
51pub use axum::http::StatusCode;
52pub use axum::body::Body as AxumBody;
53pub use axum::http::{Request as AxumRequest, Response as AxumResponse};
54pub use spring_axum_macros::*;
56
/// Initializes the global MyBatis registry from the directory `$dir`.
///
/// Calls `::spring_axum_mybatis::init_global_from_dir_once($dir)`. On error
/// the failure is logged with `tracing::error!` and the macro panics.
#[macro_export]
macro_rules! mybatis_init {
    ($dir:expr) => {{
        match ::spring_axum_mybatis::init_global_from_dir_once($dir) {
            Ok(_) => {}
            Err(e) => {
                ::tracing::error!(error = %e, "failed to init mybatis from dir");
                panic!("MyBatis init failed: {}", e);
            }
        }
    }};
}
69
/// Evaluates to the global MyBatis registry.
///
/// # Panics
/// Panics when `::spring_axum_mybatis::global_registry()` yields no registry,
/// i.e. when `mybatis_init!` has not run yet.
#[macro_export]
macro_rules! mybatis_registry {
    () => {{
        let registry = ::spring_axum_mybatis::global_registry();
        registry.expect("MyBatis global registry not initialized; call mybatis_init!(...) early")
    }};
}
78
/// Prepares the statement `$stmt_id` with the JSON parameters `$params_json`
/// and runs it through `$executor`.
///
/// Must be used inside a function whose error type can be built from
/// `AppError`, because both failure paths use `?`:
/// * unknown statement id -> `AppError::Internal("statement not found")`
/// * executor failure     -> `AppError::Internal("execute error: …")`
#[macro_export]
macro_rules! mybatis_exec {
    ($executor:expr, $stmt_id:expr, $params_json:expr) => {{
        let registry = $crate::mybatis_registry!();
        let prepared = registry.prepare_json($stmt_id, &$params_json);
        let (sql, ordered) =
            prepared.ok_or_else(|| $crate::AppError::Internal("statement not found".into()))?;
        let outcome = $executor.execute(&sql, &ordered);
        outcome.map_err(|e| $crate::AppError::Internal(format!("execute error: {}", e)))?;
    }};
}
93
/// Generates a `pub fn $fn_name(&self, args…) -> AppResult<()>` method that
/// executes the MyBatis statement `"<namespace>.<fn_name>"`.
///
/// The arguments are packed into a JSON object keyed by argument name and the
/// statement is run through a `NoopExecutor` via `mybatis_exec!`.
///
/// Crate items are referenced through `$crate` (as `mybatis_exec!` already
/// does) so the expansion does not depend on the crate being reachable under
/// the literal path `::spring_axum`.
#[macro_export]
macro_rules! sql_method {
    ($fn_name:ident ( $($arg:ident : $typ:ty),* $(,)? ) ; namespace = $ns:literal) => {
        pub fn $fn_name(&self, $($arg: $typ),*) -> $crate::AppResult<()> {
            let params = ::serde_json::json!({ $(stringify!($arg): $arg),* });
            let exec = ::spring_axum_mybatis::NoopExecutor::default();
            $crate::mybatis_exec!(exec, concat!($ns, ".", stringify!($fn_name)), params);
            Ok(())
        }
    };
}
110
/// Builds a `SpringApp` with every inventory-registered component,
/// interceptor, advice and controller applied, in that order.
///
/// Uses `$crate` so the expansion resolves regardless of the name under which
/// the crate is imported.
#[macro_export]
macro_rules! auto_discover {
    () => {{
        $crate::SpringApp::new()
            .with_discovered_components()
            .with_discovered_interceptors()
            .with_discovered_advices()
            .with_discovered_controllers()
    }};
}
123
124#[derive(Debug, Clone, serde::Serialize)]
126pub struct ErrorResponse { pub error: String }
127
128#[derive(Debug, Clone, serde::Serialize)]
129pub struct ValidationErrorsBody {
130 pub errors: HashMap<String, Vec<String>>,
131}
132
133#[derive(Debug, Clone)]
134pub enum AppError {
135 BadRequest(String),
136 NotFound(String),
137 Internal(String),
138 Validation(ValidationErrorsBody),
139 Unauthorized(String),
140 Forbidden(String),
141}
142
143pub type AppResult<T> = std::result::Result<T, AppError>;
144
145impl IntoResponse for AppError {
146 fn into_response(self) -> AxumResponse<AxumBody> {
147 match self {
148 AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, axum::Json(ErrorResponse { error: msg })).into_response(),
149 AppError::NotFound(msg) => (StatusCode::NOT_FOUND, axum::Json(ErrorResponse { error: msg })).into_response(),
150 AppError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(ErrorResponse { error: msg })).into_response(),
151 AppError::Validation(body) => (StatusCode::BAD_REQUEST, axum::Json(body)).into_response(),
152 AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, axum::Json(ErrorResponse { error: msg })).into_response(),
153 AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, axum::Json(ErrorResponse { error: msg })).into_response(),
154 }
155 }
156}
157
158impl From<JsonRejection> for AppError {
159 fn from(err: JsonRejection) -> Self { AppError::BadRequest(err.to_string()) }
160}
161
#[cfg(feature = "validator")]
impl From<validator::ValidationErrors> for AppError {
    /// Flattens field-level validation errors into `AppError::Validation`.
    ///
    /// Each error contributes its custom message when one is set, otherwise
    /// its error code.
    fn from(errs: validator::ValidationErrors) -> Self {
        let errors: HashMap<String, Vec<String>> = errs
            .field_errors()
            .into_iter()
            .map(|(field, field_errs)| {
                let messages = field_errs
                    .iter()
                    .map(|e| match e.message.as_ref() {
                        Some(msg) => msg.to_string(),
                        None => e.code.to_string(),
                    })
                    .collect::<Vec<String>>();
                (field.to_string(), messages)
            })
            .collect();
        AppError::Validation(ValidationErrorsBody { errors })
    }
}
183
/// Extractor that deserializes a JSON body into `T` and then runs
/// `validator::Validate::validate` on it.
#[cfg(feature = "validator")]
pub struct ValidatedJson<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Rejects with the converted `JsonRejection` when the body cannot be
    /// extracted, or with `AppError::Validation` when validation fails.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        // Start the extraction here so the async block owns only the future.
        let extraction = axum::extract::Json::<T>::from_request(req, state);
        async move {
            let axum::extract::Json(payload) = match extraction.await {
                Ok(json) => json,
                Err(rejection) => return Err(AppError::from(rejection)),
            };
            match payload.validate() {
                Ok(()) => Ok(ValidatedJson(payload)),
                Err(errors) => Err(AppError::from(errors)),
            }
        }
    }
}
208
/// Extractor that deserializes the query string into `T` and then runs
/// `validator::Validate::validate` on it.
#[cfg(feature = "validator")]
pub struct ValidatedQuery<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Rejects with `AppError::BadRequest` when the query cannot be
    /// extracted, or with `AppError::Validation` when validation fails.
    fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum::extract::Query::<T>::from_request_parts(parts, state);
        async move {
            let axum::extract::Query(params) = match extraction.await {
                Ok(query) => query,
                Err(rejection) => return Err(AppError::BadRequest(rejection.to_string())),
            };
            match params.validate() {
                Ok(()) => Ok(ValidatedQuery(params)),
                Err(errors) => Err(AppError::from(errors)),
            }
        }
    }
}
232
/// Extractor that deserializes a form body into `T` and then runs
/// `validator::Validate::validate` on it.
#[cfg(feature = "validator")]
pub struct ValidatedForm<T>(pub T);

#[cfg(feature = "validator")]
impl<T, S> axum::extract::FromRequest<S> for ValidatedForm<T>
where
    T: DeserializeOwned + validator::Validate,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Rejects with `AppError::BadRequest` when the form cannot be extracted,
    /// or with `AppError::Validation` when validation fails.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum::extract::Form::<T>::from_request(req, state);
        async move {
            let axum::extract::Form(fields) = match extraction.await {
                Ok(form) => form,
                Err(rejection) => return Err(AppError::BadRequest(rejection.to_string())),
            };
            match fields.validate() {
                Ok(()) => Ok(ValidatedForm(fields)),
                Err(errors) => Err(AppError::from(errors)),
            }
        }
    }
}
256
/// Extractor built on `axum_extra::extract::JsonDeserializer`: extracts the
/// deserializer, deserializes `T`, then runs `validator::Validate::validate`.
#[cfg(all(feature = "validator", feature = "axum_extra_json"))]
pub struct ValidatedJsonStream<T>(pub T);

#[cfg(all(feature = "validator", feature = "axum_extra_json"))]
impl<T, S> axum::extract::FromRequest<S> for ValidatedJsonStream<T>
where
    T: DeserializeOwned + validator::Validate + 'static,
    S: Send + Sync,
{
    type Rejection = AppError;

    /// Extraction and deserialization failures become `AppError::BadRequest`;
    /// validation failures become `AppError::Validation`.
    fn from_request(req: AxumRequest<AxumBody>, state: &S) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
        let extraction = axum_extra::extract::JsonDeserializer::<T>::from_request(req, state);
        async move {
            let deserializer = match extraction.await {
                Ok(d) => d,
                Err(rejection) => return Err(AppError::BadRequest(rejection.to_string())),
            };
            let payload = match deserializer.deserialize() {
                Ok(v) => v,
                Err(rejection) => return Err(AppError::BadRequest(rejection.to_string())),
            };
            match payload.validate() {
                Ok(()) => Ok(ValidatedJsonStream(payload)),
                Err(errors) => Err(AppError::from(errors)),
            }
        }
    }
}
284
285#[derive(Clone, serde::Deserialize, serde::Serialize, Debug)]
286pub struct ServerConfig {
287 pub host: String,
288 pub port: u16,
289 #[serde(default)]
290 pub enable_cors: bool,
291}
292
293impl Default for ServerConfig {
294 fn default() -> Self {
295 Self {
296 host: "127.0.0.1".into(),
297 port: 3000,
298 enable_cors: true,
299 }
300 }
301}
302
303#[derive(Clone, serde::Deserialize, serde::Serialize, Debug, Default)]
304pub struct AppConfig {
305 #[serde(default)]
306 pub server: ServerConfig,
307 #[serde(default)]
308 pub devtools: DevtoolsConfig,
309 #[serde(default)]
310 pub tx_advice: TxAdviceConfig,
311}
312
313#[derive(Clone, serde::Deserialize, serde::Serialize, Debug)]
314pub struct DevtoolsConfig {
315 #[serde(default)]
316 pub enabled: bool,
317 #[serde(default)]
318 pub watch: Vec<String>,
319}
320
321impl Default for DevtoolsConfig {
322 fn default() -> Self {
323 Self { enabled: false, watch: vec!["src".into(), "resources".into(), "templates".into()] }
324 }
325}
326
327#[derive(Clone, serde::Deserialize, serde::Serialize, Debug, Default)]
328pub struct TxAdviceConfig {
329 #[serde(default)]
330 pub enabled: bool,
331 #[serde(default)]
332 pub pointcut: Option<String>,
333}
334
335fn load_config() -> AppConfig {
336 let base_path = "application.yaml";
337 let mut config = match fs::read_to_string(base_path) {
338 Ok(s) => serde_yaml::from_str::<AppConfig>(&s).unwrap_or_default(),
339 Err(_) => AppConfig::default(),
340 };
341
342 if let Ok(profile) = std::env::var("SPRING_PROFILE") {
343 let prof_path = format!("application-{}.yaml", profile);
344 if let Ok(s) = fs::read_to_string(&prof_path) {
345 if let Ok(overlay) = serde_yaml::from_str::<AppConfig>(&s) {
346 config.server = overlay.server;
348 }
349 }
350 }
351
352 if let Ok(host) = std::env::var("SERVER_HOST") {
354 config.server.host = host;
355 }
356 if let Ok(port) = std::env::var("SERVER_PORT") {
357 if let Ok(p) = port.parse::<u16>() {
358 config.server.port = p;
359 }
360 }
361 if let Ok(cors) = std::env::var("SERVER_ENABLE_CORS") {
362 config.server.enable_cors = matches!(cors.as_str(), "1" | "true" | "True" | "TRUE");
363 }
364 if let Ok(enabled) = std::env::var("DEVTOOLS_ENABLED") {
366 config.devtools.enabled = matches!(enabled.as_str(), "1" | "true" | "True" | "TRUE");
367 }
368 if let Ok(watch) = std::env::var("DEVTOOLS_WATCH") {
369 let list: Vec<String> = watch
370 .split([',', ';', '|'])
371 .map(|s| s.trim().to_string())
372 .filter(|s| !s.is_empty())
373 .collect();
374 if !list.is_empty() { config.devtools.watch = list; }
375 }
376 config
377}
378
379fn validate_config(cfg: &AppConfig) {
380 if cfg.server.host.trim().is_empty() {
381 tracing::warn!(target = "spring_axum", "server.host is empty; defaulting may cause bind errors");
382 }
383 if cfg.server.port == 0 {
384 tracing::warn!(target = "spring_axum", port = cfg.server.port, "server.port must be in 1..=65535");
385 }
386 if cfg.server.host.starts_with("http://") || cfg.server.host.starts_with("https://") {
387 tracing::warn!(target = "spring_axum", host = %cfg.server.host, "server.host should be a hostname/ip, not URL");
388 }
389 #[cfg(feature = "devtools")]
390 if cfg.devtools.enabled {
391 for p in &cfg.devtools.watch {
392 if !std::path::Path::new(p).exists() {
393 tracing::info!(target = "spring_axum", path = %p, "devtools watch path not found (will still attempt watching)");
394 }
395 }
396 }
397}
398
399pub trait Controller {
400 fn routes(&self) -> Router;
401}
402
403pub struct SpringApp {
404 router: Router,
405 config: AppConfig,
406 ctx: Arc<ApplicationContext>,
407}
408
409impl SpringApp {
410 pub fn new() -> Self {
411 let config = load_config();
412 validate_config(&config);
413 let router = Router::new().route("/actuator/health", get(|| async { "OK" }));
414 let ctx = Arc::new(ApplicationContext::new());
415 init_defaults_once();
416 let router = router.merge(cache_metrics_router());
418 #[cfg(feature = "swagger")]
419 let router = router.merge(swagger_routes());
420 #[cfg(feature = "devtools")]
421 let router = if config.devtools.enabled { router.merge(devtools_router()) } else { router };
422 #[cfg(feature = "devtools")]
423 if config.devtools.enabled { start_devtools_watcher(&config); }
424 Self { router, config, ctx }
425 }
426
427 pub fn with_controller(mut self, controller: impl Controller) -> Self {
428 self.router = self.router.merge(controller.routes());
429 self
430 }
431
432 pub fn with_router(mut self, router: Router) -> Self {
433 self.router = self.router.merge(router);
434 self
435 }
436
437 pub async fn run(self) -> Result<()> {
438 init_tracing();
439
440 auto_init_integrations();
442
443 print_banner_if_any();
444
445 tracing::info!(target = "spring_axum", host = %self.config.server.host, port = self.config.server.port, cors = self.config.server.enable_cors, devtools = self.config.devtools.enabled, "Startup configuration");
446
447 publish_event(AppStarting { config: self.config.clone() }, &self.ctx);
449
450 let mut router = self
451 .router
452 .layer(TraceLayer::new_for_http())
453 .layer(AddExtensionLayer::new(self.ctx.clone()));
454 if self.config.server.enable_cors {
455 router = router.layer(CorsLayer::permissive());
456 }
457
458 let addr: SocketAddr = format!("{}:{}", self.config.server.host, self.config.server.port)
459 .parse()
460 .expect("invalid server address");
461 tracing::info!(target: "spring_axum", "Listening on http://{}", addr);
462 axum::serve(tokio::net::TcpListener::bind(addr).await?, router).await?;
463 Ok(())
464 }
465}
466
467fn init_tracing() {
468 let _ = fmt()
469 .with_env_filter(
470 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
471 )
472 .with_target(false)
473 .compact()
474 .try_init();
475}
476
/// Type-indexed registry of shared components.
///
/// Values are stored as `Arc<T>` boxed behind `dyn Any`, keyed by the
/// `TypeId` of `T`.
pub struct ApplicationContext {
    inner: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl ApplicationContext {
    /// Creates an empty context.
    pub fn new() -> Self {
        let inner = RwLock::new(HashMap::new());
        Self { inner }
    }

    /// Stores `value` under the type id of `T`, replacing any previous entry.
    pub fn register<T: Send + Sync + 'static>(&self, value: T) {
        let shared: Box<dyn Any + Send + Sync> = Box::new(Arc::new(value));
        let mut components = self.inner.write().unwrap();
        components.insert(TypeId::of::<T>(), shared);
    }

    /// Stores an already boxed value under `type_id`.
    ///
    /// For `resolve::<T>()` to find it, the box must contain an `Arc<T>` and
    /// `type_id` must be the id of `T`.
    pub fn register_dynamic(&self, type_id: StdTypeId, arc_any: Box<dyn Any + Send + Sync>) {
        let mut components = self.inner.write().unwrap();
        components.insert(type_id, arc_any);
    }

    /// Returns the component registered for `T`, or `None` when there is no
    /// entry or the stored value is not an `Arc<T>`.
    pub fn resolve<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        let components = self.inner.read().unwrap();
        let entry = components.get(&TypeId::of::<T>())?;
        entry.downcast_ref::<Arc<T>>().cloned()
    }
}
507
508pub trait TransactionManager: Send + Sync + 'static {
510 fn begin(&self) -> BoxFuture<'static, Result<(), AppError>>;
511 fn commit(&self) -> BoxFuture<'static, Result<(), AppError>>;
512 fn rollback(&self) -> BoxFuture<'static, Result<(), AppError>>;
513}
514
515#[derive(Clone, Default)]
516pub struct NoopTransactionManager;
517
518impl TransactionManager for NoopTransactionManager {
519 fn begin(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
520 fn commit(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
521 fn rollback(&self) -> BoxFuture<'static, Result<(), AppError>> { Box::pin(async { Ok(()) }) }
522}
523
524static GLOBAL_TX_MANAGER: OnceCell<Arc<dyn TransactionManager>> = OnceCell::new();
525
526fn init_defaults_once() {
527 let _ = GLOBAL_TX_MANAGER.set(Arc::new(NoopTransactionManager::default()));
529 let _ = GLOBAL_CACHE.set(InMemoryCache::default());
530}
531
532fn auto_init_integrations() {
534 let mybatis_dir = std::path::Path::new("resources/mybatis");
536 if mybatis_dir.exists() {
537 if spring_axum_mybatis::global_registry().is_none() {
538 match spring_axum_mybatis::init_global_from_dir_once("resources/mybatis") {
539 Ok(_) => tracing::info!(target = "spring_axum", path = %mybatis_dir.display(), "MyBatis registry initialized"),
540 Err(e) => tracing::warn!(target = "spring_axum", error = %e, "Failed to init MyBatis registry"),
541 }
542 }
543 } else {
544 tracing::debug!(target = "spring_axum", "resources/mybatis not found; skipping MyBatis init");
545 }
546
547 #[cfg(feature = "sqlx_postgres")]
549 {
550 if GLOBAL_SQLX_POOL.get().is_none() {
551 match std::env::var("DATABASE_URL") {
552 Ok(url) => {
553 match sqlx::postgres::PgPoolOptions::new().connect_lazy(&url) {
554 Ok(pool) => {
555 let _ = GLOBAL_SQLX_POOL.set(pool);
556 tracing::info!(target = "spring_axum", "sqlx pool initialized from env DATABASE_URL");
557 }
558 Err(e) => tracing::warn!(target = "spring_axum", error = %e, "Failed to lazy connect sqlx pool"),
559 }
560 }
561 Err(_) => tracing::debug!(target = "spring_axum", "DATABASE_URL not set; skipping sqlx pool init"),
562 }
563 }
564 }
565}
566
567pub fn set_transaction_manager<T: TransactionManager>(mgr: Arc<T>) {
568 let _ = GLOBAL_TX_MANAGER.set(mgr);
569}
570
571pub async fn transaction<F, Fut, T>(f: F) -> AppResult<T>
572where
573 F: FnOnce() -> Fut,
574 Fut: std::future::Future<Output = AppResult<T>>,
575{
576 let mgr = GLOBAL_TX_MANAGER.get().cloned().unwrap_or_else(|| Arc::new(NoopTransactionManager));
577 if let Err(e) = mgr.begin().await { return Err(e); }
578 match f().await {
579 Ok(val) => {
580 if let Err(e) = mgr.commit().await { return Err(e); }
581 Ok(val)
582 }
583 Err(err) => {
584 let _ = mgr.rollback().await;
585 Err(err)
586 }
587 }
588}
589
590#[derive(Clone, Default)]
592pub struct InMemoryCache {
593 inner: Arc<RwLock<HashMap<String, (Box<dyn Any + Send + Sync>, Instant, Option<Duration>, StdTypeId)>>>,
594 stats: Arc<RwLock<CacheStats>>,
595}
596
597#[derive(Default, Clone, serde::Serialize)]
598pub struct CacheStats {
599 hits: u64,
600 misses: u64,
601 puts: u64,
602 evicts: u64,
603}
604
605impl InMemoryCache {
606 pub fn get_typed<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<T> {
607 let mut guard = self.inner.write().unwrap();
608 if let Some((boxed_any, inserted, ttl_opt, type_id)) = guard.get(key) {
609 let expired = match ttl_opt {
611 Some(ttl) => *inserted + *ttl < Instant::now(),
612 None => false,
613 };
614 if expired {
615 guard.remove(key);
616 self.stats.write().unwrap().misses += 1;
617 return None;
618 }
619 if *type_id == StdTypeId::of::<T>() {
620 if let Some(arc_val) = (&**boxed_any as &dyn Any).downcast_ref::<Arc<T>>() {
621 self.stats.write().unwrap().hits += 1;
622 return Some((**arc_val).clone());
623 }
624 }
625 self.stats.write().unwrap().misses += 1;
626 None
627 } else {
628 self.stats.write().unwrap().misses += 1;
629 None
630 }
631 }
632
633 pub fn put_typed<T: Clone + Send + Sync + 'static>(&self, key: String, value: T, ttl: Option<Duration>) {
634 let boxed: Box<dyn Any + Send + Sync> = Box::new(Arc::new(value));
635 self.inner
636 .write()
637 .unwrap()
638 .insert(key, (boxed, Instant::now(), ttl, StdTypeId::of::<T>()));
639 self.stats.write().unwrap().puts += 1;
640 }
641
642 pub fn evict(&self, key: &str) {
643 self.inner.write().unwrap().remove(key);
644 self.stats.write().unwrap().evicts += 1;
645 }
646
647 pub fn stats(&self) -> CacheStats { self.stats.read().unwrap().clone() }
648}
649
650static GLOBAL_CACHE: OnceCell<InMemoryCache> = OnceCell::new();
651
652pub fn cache_instance() -> InMemoryCache {
653 GLOBAL_CACHE.get().cloned().unwrap_or_default()
654}
655
656pub fn default_cache_key(fn_name: &str, args_json: &serde_json::Value) -> String {
657 format!("{}:{}", fn_name, args_json)
658}
659
660fn cache_metrics_router() -> Router {
662 let handler = || async move {
663 let stats = cache_instance().stats();
664 axum::Json(stats)
665 };
666 Router::new().route("/actuator/metrics/cache", get(handler))
667}
668
669fn print_banner_if_any() {
670 if let Ok(banner) = fs::read_to_string("banner.txt") {
671 println!("{}", banner);
672 } else {
673 println!(":: Spring-Axum v{} ::", env!("CARGO_PKG_VERSION"));
675 let mut features: Vec<&str> = Vec::new();
676 if cfg!(feature = "validator") { features.push("validator"); }
677 if cfg!(feature = "swagger") { features.push("swagger"); }
678 if cfg!(feature = "sqlx_postgres") { features.push("sqlx_postgres"); }
679 if cfg!(feature = "devtools") { features.push("devtools"); }
680 if !features.is_empty() {
681 println!("Enabled features: {}", features.join(", "));
682 }
683 }
684}
685
/// Broadcast channel carrying devtools notifications to SSE subscribers.
#[cfg(feature = "devtools")]
static DEVTOOLS_BUS: OnceCell<broadcast::Sender<String>> = OnceCell::new();

/// Returns a sender for the devtools bus, creating the channel (capacity
/// 128) on first use.
#[cfg(feature = "devtools")]
fn devtools_bus() -> broadcast::Sender<String> {
    DEVTOOLS_BUS.get_or_init(|| broadcast::channel(128).0).clone()
}

/// Routes: `/devtools/events` (SSE stream) and `/devtools/trigger` (manual
/// notification).
#[cfg(feature = "devtools")]
fn devtools_router() -> Router {
    Router::new()
        .route("/devtools/events", get(devtools_sse))
        .route("/devtools/trigger", get(devtools_trigger))
}

/// Streams every bus message as an SSE event named `reload`, with a
/// keep-alive every 10 seconds. Receive errors on the broadcast stream are
/// skipped.
#[cfg(feature = "devtools")]
async fn devtools_sse() -> Sse<impl futures::Stream<Item = Result<SseEvent, std::convert::Infallible>>> {
    let rx = devtools_bus().subscribe();
    let stream = BroadcastStream::new(rx).filter_map(|res| match res {
        Ok(msg) => Some(Ok(SseEvent::default().event("reload").data(msg))),
        Err(_) => None,
    });
    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10)).text("keep-alive"))
}

/// Publishes the message `manual` on the bus and answers `ok`. A send error
/// (no subscribers) is ignored.
#[cfg(feature = "devtools")]
async fn devtools_trigger() -> &'static str {
    let _ = devtools_bus().send("manual".into());
    "ok"
}

/// Spawns a background thread that watches the configured paths and
/// publishes `reload` plus `changed: [...]` for every file-system event, or
/// `watch_error: ...` for watcher errors.
///
/// A watcher that cannot be created is logged and the thread exits instead
/// of panicking. Paths that cannot be watched are logged and skipped.
#[cfg(feature = "devtools")]
fn start_devtools_watcher(cfg: &AppConfig) {
    if !cfg.devtools.enabled {
        return;
    }
    let tx = devtools_bus();
    let paths = cfg.devtools.watch.clone();
    std::thread::spawn(move || {
        let (evt_tx, evt_rx) = std::sync::mpsc::channel();
        let created = RecommendedWatcher::new(
            move |res: notify::Result<notify::Event>| {
                let _ = evt_tx.send(res);
            },
            NotifyConfig::default(),
        );
        let mut watcher = match created {
            Ok(w) => w,
            Err(e) => {
                tracing::error!(target: "spring_axum", error = %e, "Failed to init file watcher");
                return;
            }
        };
        for p in &paths {
            if let Err(e) = watcher.watch(std::path::Path::new(p), RecursiveMode::Recursive) {
                tracing::warn!(target: "spring_axum", path = %p, error = %e, "failed to watch path");
            }
        }
        // The loop ends when the watcher, which owns the sending side, is dropped.
        while let Ok(res) = evt_rx.recv() {
            match res {
                Ok(event) => {
                    let _ = tx.send("reload".into());
                    let _ = tx.send(format!("changed: {:?}", event.paths));
                }
                Err(e) => {
                    let _ = tx.send(format!("watch_error: {}", e));
                }
            }
        }
    });
}
741
742pub struct Component<T>(pub Arc<T>);
743
744impl<S, T> axum::extract::FromRequestParts<S> for Component<T>
745where
746 T: Send + Sync + 'static,
747{
748 type Rejection = (axum::http::StatusCode, String);
749
750 fn from_request_parts(
751 parts: &mut axum::http::request::Parts,
752 _state: &S,
753 ) -> impl futures::Future<Output = Result<Self, Self::Rejection>> + Send {
754 let ctx_opt = parts.extensions.get::<Arc<ApplicationContext>>().cloned();
755 async move {
756 if let Some(ctx) = ctx_opt {
757 if let Some(v) = ctx.resolve::<T>() {
758 return Ok(Component(v));
759 }
760 }
761 Err((axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Component not found".into()))
762 }
763 }
764}
765
766impl SpringApp {
767 pub fn with_component<T: Send + Sync + 'static>(self, value: T) -> Self {
768 self.ctx.register::<T>(value);
769 self
770 }
771
772 pub fn with_discovered_components(self) -> Self {
773 for reg in inventory::iter::<ComponentRegistration> {
774 let (type_id, arc_any) = (reg.init)(&self.ctx);
775 self.ctx.register_dynamic(type_id, arc_any);
776 }
777 self
778 }
779}
780
781pub struct ComponentRegistration {
783 pub init: fn(&ApplicationContext) -> (StdTypeId, Box<dyn Any + Send + Sync>),
784}
785
786inventory::collect!(ComponentRegistration);
787
788pub struct ControllerRouterRegistration {
790 pub init: fn() -> Router,
791}
792
793inventory::collect!(ControllerRouterRegistration);
794
795impl SpringApp {
796 pub fn with_discovered_controllers(mut self) -> Self {
797 for reg in inventory::iter::<ControllerRouterRegistration> {
798 let router = (reg.init)();
799 self.router = self.router.merge(router);
800 }
801 self
802 }
803}
804
805pub trait Interceptor: Clone + Send + Sync + 'static {
807 fn on_request(&self, req: Request<Body>) -> Request<Body> { req }
808 fn on_response(&self, res: Response<Body>) -> Response<Body> { res }
809}
810
811pub struct InterceptorRegistration {
813 pub apply: fn(Router) -> Router,
814}
815
816inventory::collect!(InterceptorRegistration);
817
818impl SpringApp {
819 pub fn with_discovered_interceptors(mut self) -> Self {
820 for reg in inventory::iter::<InterceptorRegistration> {
821 self.router = (reg.apply)(self.router);
822 }
823 self
824 }
825}
826
827#[derive(Clone)]
828pub struct InterceptorLayer<I> {
829 interceptor: I,
830}
831
832impl<I> InterceptorLayer<I> {
833 pub fn new(interceptor: I) -> Self { Self { interceptor } }
834}
835
836impl<S, I> Layer<S> for InterceptorLayer<I>
837where
838 I: Interceptor,
839{
840 type Service = InterceptorService<S, I>;
841 fn layer(&self, inner: S) -> Self::Service { InterceptorService { inner, interceptor: self.interceptor.clone() } }
842}
843
844#[derive(Clone)]
845pub struct InterceptorService<S, I> {
846 inner: S,
847 interceptor: I,
848}
849
850impl<S, I> Service<Request<Body>> for InterceptorService<S, I>
851where
852 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
853 S::Future: Send + 'static,
854 I: Interceptor,
855{
856 type Response = Response<Body>;
857 type Error = S::Error;
858 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
859
860 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
861 self.inner.poll_ready(cx)
862 }
863
864 fn call(&mut self, req: Request<Body>) -> Self::Future {
865 let interceptor = self.interceptor.clone();
866 let mut inner = self.inner.clone();
867 let req = interceptor.on_request(req);
868 Box::pin(async move {
869 let res = inner.call(req).await?;
870 Ok(interceptor.on_response(res))
871 })
872 }
873}
874
875impl SpringApp {
876 pub fn with_interceptor<I: Interceptor>(mut self, interceptor: I) -> Self {
877 self.router = self.router.layer(InterceptorLayer::new(interceptor));
878 self
879 }
880}
881
882#[derive(Clone)]
884pub struct Pointcut {
885 methods: Option<Vec<axum::http::Method>>,
886 path_re: Option<Regex>,
887}
888
889impl Pointcut {
890 pub fn matches(&self, req: &Request<Body>) -> bool {
891 let method_ok = match &self.methods {
892 Some(ms) => ms.iter().any(|m| m == req.method()),
893 None => true,
894 };
895 let path_ok = match &self.path_re {
896 Some(re) => re.is_match(req.uri().path()),
897 None => true,
898 };
899 method_ok && path_ok
900 }
901}
902
/// Translates a path glob into an anchored regular expression.
///
/// * `**` becomes `.*` (any characters, including `/`)
/// * `*` becomes `[^/]*` (any characters within one path segment)
/// * regex metacharacters, including the backslash, are escaped
/// * every other character is copied verbatim
///
/// The input is walked by `char`, so non-ASCII text is preserved instead of
/// being split into individual bytes.
fn glob_to_regex(glob: &str) -> String {
    let mut pattern = String::with_capacity(glob.len() + 2);
    pattern.push('^');

    let mut chars = glob.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '*' => {
                if chars.peek() == Some(&'*') {
                    // Consume the second star of `**`.
                    chars.next();
                    pattern.push_str(".*");
                } else {
                    pattern.push_str("[^/]*");
                }
            }
            '.' | '+' | '?' | '(' | ')' | '|' | '{' | '}' | '[' | ']' | '^' | '$' | '\\' => {
                pattern.push('\\');
                pattern.push(c);
            }
            other => pattern.push(other),
        }
    }

    pattern.push('$');
    pattern
}
932
933fn parse_pointcut(expr: &str) -> Pointcut {
934 let parts: Vec<&str> = expr.split_whitespace().collect();
935 if parts.is_empty() {
936 return Pointcut { methods: None, path_re: None };
937 }
938 if parts.len() == 1 {
939 let re = Regex::new(&glob_to_regex(parts[0])).ok();
940 return Pointcut { methods: None, path_re: re };
941 }
942 let methods: Option<Vec<axum::http::Method>> = {
943 let ms = parts[0]
944 .split(',')
945 .filter_map(|m| axum::http::Method::from_bytes(m.trim().as_bytes()).ok())
946 .collect::<Vec<_>>();
947 if ms.is_empty() { None } else { Some(ms) }
948 };
949 let path_re = Regex::new(&glob_to_regex(parts[1])).ok();
950 Pointcut { methods, path_re }
951}
952
953pub trait Advice: Send + Sync + 'static {
954 fn before(&self, req: Request<Body>) -> Request<Body> { req }
955 fn after(&self, res: Response<Body>) -> Response<Body> { res }
956 fn before_async(&self, req: Request<Body>) -> futures::future::BoxFuture<'static, Request<Body>> {
957 Box::pin(async move { req })
958 }
959 fn after_async(&self, res: Response<Body>) -> futures::future::BoxFuture<'static, Response<Body>> {
960 Box::pin(async move { res })
961 }
962 fn pointcut_expr(&self) -> &'static str;
963 fn pointcut(&self) -> Pointcut { parse_pointcut(self.pointcut_expr()) }
964}
965
966#[derive(Clone)]
967pub struct AdviceLayer {
968 advice: Arc<dyn Advice>,
969 pointcut: Pointcut,
970}
971
972impl AdviceLayer {
973 pub fn new(advice: Box<dyn Advice>) -> Self { Self { pointcut: advice.pointcut(), advice: Arc::from(advice) } }
974}
975
976#[derive(Clone)]
977pub struct AdviceService<S> {
978 inner: S,
979 advice: Arc<dyn Advice>,
980 pointcut: Pointcut,
981}
982
983impl<S> Layer<S> for AdviceLayer {
984 type Service = AdviceService<S>;
985 fn layer(&self, inner: S) -> Self::Service { AdviceService { inner, advice: self.advice.clone(), pointcut: self.pointcut.clone() } }
986}
987
988impl<S> Service<Request<Body>> for AdviceService<S>
989where
990 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
991 S::Future: Send + 'static,
992{
993 type Response = Response<Body>;
994 type Error = S::Error;
995 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
996
997 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) }
998
999 fn call(&mut self, req: Request<Body>) -> Self::Future {
1000 let mut inner = self.inner.clone();
1001 let advice = self.advice.clone();
1002 let pointcut = self.pointcut.clone();
1003 let matched = pointcut.matches(&req);
1004 let req = if matched { (*advice).before(req) } else { req };
1005 Box::pin(async move {
1006 let req2 = if matched { (*advice).before_async(req).await } else { req };
1007 let res = inner.call(req2).await?;
1008 let res2 = if matched { (*advice).after(res) } else { res };
1009 let res3 = if matched { (*advice).after_async(res2).await } else { res2 };
1010 Ok(res3)
1011 })
1012 }
1013}
1014
1015pub struct AnyAdviceRegistration {
1017 pub build_boxed: fn() -> Box<dyn Advice>,
1018}
1019
1020inventory::collect!(AnyAdviceRegistration);
1021
1022impl SpringApp {
1023 pub fn with_discovered_advices(mut self) -> Self {
1024 for reg in inventory::iter::<AnyAdviceRegistration> {
1025 let boxed = (reg.build_boxed)();
1026 self.router = self.router.layer(AdviceLayer::new(boxed));
1027 }
1028 if self.config.tx_advice.enabled {
1030 let expr = self
1031 .config
1032 .tx_advice
1033 .pointcut
1034 .as_deref()
1035 .unwrap_or("method:* path:/**");
1036 let expr_static: &'static str = Box::leak(expr.to_string().into_boxed_str());
1038 let boxed: Box<dyn Advice> = Box::new(TransactionAdvice::new(expr_static));
1039 self.router = self.router.layer(AdviceLayer::new(boxed));
1040 }
1041 self
1042 }
1043}
1044
1045pub struct EventListenerRegistration {
1047 pub matches: fn(StdTypeId) -> bool,
1048 pub handle: fn(&dyn Any, &ApplicationContext),
1049}
1050
1051inventory::collect!(EventListenerRegistration);
1052
1053pub fn publish_event<T: Send + Sync + 'static>(evt: T, ctx: &ApplicationContext) {
1054 let type_id = StdTypeId::of::<T>();
1055 for reg in inventory::iter::<EventListenerRegistration> {
1056 if (reg.matches)(type_id) {
1057 (reg.handle)(&evt, ctx);
1058 }
1059 }
1060}
1061
/// Event payload carrying the application configuration.
///
/// NOTE(review): the name suggests it is published while the application
/// starts — confirm against the code that publishes it.
#[derive(Clone, Debug)]
pub struct AppStarting {
    /// Configuration carried with the event.
    pub config: AppConfig,
}
1066
/// Registers `$handler` as a listener for events of type `$ty`.
///
/// Expands to an `inventory::submit!` of an `EventListenerRegistration` whose
/// `matches` compares the incoming `TypeId` with that of `$ty`, and whose
/// `handle` downcasts the event to `$ty` and calls `$handler(&event, ctx)`.
///
/// Crate items are named through `$crate`, so the macro keeps working when a
/// dependent renames this crate or when it is invoked from inside the crate
/// itself. `inventory::submit!` is still named by a bare path, so `inventory`
/// must be resolvable at the call site.
#[macro_export]
macro_rules! event_listener {
    ($ty:ty, $handler:path) => {
        inventory::submit! {
            $crate::EventListenerRegistration {
                matches: |incoming: ::std::any::TypeId| incoming == ::std::any::TypeId::of::<$ty>(),
                handle: |evt: &dyn ::std::any::Any, ctx: &$crate::ApplicationContext| {
                    if let Some(v) = evt.downcast_ref::<$ty>() {
                        $handler(v, ctx);
                    }
                },
            }
        }
    };
}
1082
/// Registers the advice type `$ty`, built with its `Default` implementation.
///
/// Expands to an `inventory::submit!` of an `AnyAdviceRegistration` whose
/// `build_boxed` returns `Box::new(<$ty as Default>::default())`.
///
/// The registration type is named through `$crate`, so the macro keeps working
/// when a dependent renames this crate or when it is invoked from inside the
/// crate itself. `inventory::submit!` is still named by a bare path, so
/// `inventory` must be resolvable at the call site.
#[macro_export]
macro_rules! advice {
    ($ty:ty) => {
        inventory::submit! {
            $crate::AnyAdviceRegistration {
                build_boxed: || Box::new(<$ty as Default>::default()),
            }
        }
    };
}
1093
/// Advice that wraps matching requests in a transaction on the global
/// transaction manager.
#[derive(Clone)]
pub struct TransactionAdvice {
    /// Pointcut expression selecting the requests this advice applies to.
    expr: &'static str,
}

impl TransactionAdvice {
    /// Creates an advice bound to the pointcut expression `expr`.
    ///
    /// Being `const`, it can be used in constant and static initializers.
    pub const fn new(expr: &'static str) -> Self {
        TransactionAdvice { expr }
    }
}
1102
1103impl Advice for TransactionAdvice {
1104 fn pointcut_expr(&self) -> &'static str { self.expr }
1105 fn before_async(&self, req: Request<Body>) -> futures::future::BoxFuture<'static, Request<Body>> {
1106 Box::pin(async move {
1107 if let Some(mgr) = GLOBAL_TX_MANAGER.get() {
1108 let _ = mgr.begin().await;
1109 }
1110 req
1111 })
1112 }
1113 fn after_async(&self, res: Response<Body>) -> futures::future::BoxFuture<'static, Response<Body>> {
1114 Box::pin(async move {
1115 if let Some(mgr) = GLOBAL_TX_MANAGER.get() {
1116 if res.status().is_success() {
1117 let _ = mgr.commit().await;
1118 } else {
1119 let _ = mgr.rollback().await;
1120 }
1121 }
1122 res
1123 })
1124 }
1125}
1126
/// Registers a `TransactionAdvice` for the literal pointcut expression `$expr`.
///
/// Expands to an `inventory::submit!` of an `AnyAdviceRegistration` whose
/// `build_boxed` returns `Box::new(TransactionAdvice::new($expr))`.
///
/// Crate items are named through `$crate`, so the macro keeps working when a
/// dependent renames this crate or when it is invoked from inside the crate
/// itself. `inventory::submit!` is still named by a bare path, so `inventory`
/// must be resolvable at the call site.
#[macro_export]
macro_rules! transaction_advice {
    ($expr:literal) => {
        inventory::submit! {
            $crate::AnyAdviceRegistration {
                build_boxed: || Box::new($crate::TransactionAdvice::new($expr)),
            }
        }
    };
}
1137
/// Builds the base OpenAPI document: title "Spring Axum", version "0.1.0",
/// and a path set built from an empty `PathsBuilder`.
#[cfg(feature = "swagger")]
fn build_openapi() -> utoipa::openapi::OpenApi {
    let info = InfoBuilder::new()
        .title("Spring Axum")
        .version("0.1.0")
        .build();
    let paths = PathsBuilder::new().build();
    OpenApiBuilder::new().info(info).paths(paths).build()
}
1146
/// Builds the router serving the placeholder Swagger page and the OpenAPI JSON.
///
/// Routes:
/// - `GET /swagger-ui` — a static HTML page linking to the JSON document.
/// - `GET /api-docs/openapi.json` — the JSON cached in `OPENAPI_JSON`, or a
///   minimal fallback document when the cache is empty.
///
/// The OpenAPI document is built once through `OPENAPI` and its JSON form is
/// stored in `OPENAPI_JSON`; on later calls the stored value is kept.
#[cfg(feature = "swagger")]
fn swagger_routes() -> Router {
    /// Returns the placeholder Swagger page.
    fn ui_html() -> String {
        // Raw string literal: quotes are written as-is. Escaping them with a
        // backslash would put literal backslashes into the served HTML and
        // break the `charset` and `href` attributes.
        r#"<!DOCTYPE html>
<html>
 <head>
 <meta charset="utf-8" />
 <title>Swagger UI</title>
 </head>
 <body>
 <h1>Swagger UI (placeholder)</h1>
 <p>OpenAPI JSON at <a href="/api-docs/openapi.json">/api-docs/openapi.json</a></p>
 </body>
</html>"#.to_string()
    }

    /// Minimal document used when serialization fails or nothing is cached.
    fn fallback_openapi_json() -> serde_json::Value {
        serde_json::json!({
            "openapi": "3.0.0",
            "info": {"title": "Spring Axum", "version": "0.1.0"},
            "paths": {}
        })
    }

    let openapi = OPENAPI.get_or_init(build_openapi);
    let openapi_json =
        serde_json::to_value(openapi).unwrap_or_else(|_| fallback_openapi_json());
    // A second call finds the cell already filled; keeping the first value is intended.
    let _ = OPENAPI_JSON.set(openapi_json);

    Router::new()
        .route("/swagger-ui", get(|| async { Html(ui_html()) }))
        .route("/api-docs/openapi.json", get(|| async {
            let v = OPENAPI_JSON
                .get()
                .cloned()
                .unwrap_or_else(fallback_openapi_json);
            axum::Json(v)
        }))
}
1179#[cfg(feature = "sqlx_postgres")]
1180use sqlx::{Pool, Postgres, PgConnection};
/// Process-wide Postgres pool: written by `set_sqlx_pool`, read by
/// `sqlx_pool` and `sqlx_transaction`.
#[cfg(feature = "sqlx_postgres")]
static GLOBAL_SQLX_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();
1183
/// Installs `pool` as the global Postgres pool.
///
/// The outcome of the underlying `OnceCell::set` is discarded, so a call made
/// after a pool is already installed has no effect and drops `pool`.
#[cfg(feature = "sqlx_postgres")]
pub fn set_sqlx_pool(pool: Pool<Postgres>) {
    drop(GLOBAL_SQLX_POOL.set(pool));
}
1188
/// Returns a clone of the global Postgres pool, or `None` when no pool has
/// been installed.
#[cfg(feature = "sqlx_postgres")]
pub fn sqlx_pool() -> Option<Pool<Postgres>> {
    let pool = GLOBAL_SQLX_POOL.get()?;
    Some(pool.clone())
}
1191
/// Runs `f` inside a Postgres transaction on a connection from the global pool.
///
/// The transaction is committed when `f` resolves to `Ok` and rolled back when
/// it resolves to `Err`; the closure's value or error is passed through.
///
/// The transaction is managed by sqlx's `Transaction` guard rather than raw
/// `BEGIN`/`COMMIT`/`ROLLBACK` statements. If this future is dropped while `f`
/// is still running (for example because the request was cancelled), or `f`
/// panics, the guard rolls the transaction back instead of returning a
/// connection with an open transaction to the pool.
///
/// # Errors
///
/// Returns `AppError::Internal` when no pool has been installed, when a
/// connection cannot be acquired, or when beginning or committing the
/// transaction fails. An error returned by `f` is returned unchanged; a
/// failure of the subsequent rollback is ignored in favour of that error.
#[cfg(feature = "sqlx_postgres")]
pub async fn sqlx_transaction<F, Fut, T>(f: F) -> AppResult<T>
where
    F: for<'a> FnOnce(&'a mut PgConnection) -> Fut,
    Fut: std::future::Future<Output = AppResult<T>> + Send,
    T: Send + 'static,
{
    let pool = GLOBAL_SQLX_POOL
        .get()
        .ok_or_else(|| AppError::Internal("sqlx pool not set".into()))?;
    let mut conn = pool
        .acquire()
        .await
        .map_err(|e| AppError::Internal(format!("acquire conn error: {}", e)))?;
    // Fully-qualified trait call, so no extra `use sqlx::Connection` is needed.
    let mut tx = sqlx::Connection::begin(&mut *conn)
        .await
        .map_err(|e| AppError::Internal(format!("begin tx error: {}", e)))?;
    match f(&mut *tx).await {
        Ok(val) => {
            tx.commit()
                .await
                .map_err(|e| AppError::Internal(format!("commit error: {}", e)))?;
            Ok(val)
        }
        Err(err) => {
            // The closure's error is what the caller needs; a rollback failure
            // would only mask it.
            let _ = tx.rollback().await;
            Err(err)
        }
    }
}