1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Ident, ItemFn, ItemStruct, ItemImpl, ImplItem, LitStr, Expr, FnArg, Type, TypePath};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::Token;
7
8fn route_macro_impl(method_ident: &str, path: LitStr, input: ItemFn, layer: Option<Expr>) -> TokenStream {
9 let fn_name: Ident = input.sig.ident.clone();
10 let route_fn_name = Ident::new(&format!("__spring_axum_route_{}", fn_name), fn_name.span());
11 let method_ident = Ident::new(method_ident, fn_name.span());
12 let base_route = quote! { ::spring_axum::Router::new().route(#path, ::spring_axum::#method_ident(#fn_name)) };
13 let router_stmt = if let Some(layer_expr) = layer {
14 quote! { #base_route.route_layer(#layer_expr) }
15 } else {
16 base_route
17 };
18 let expanded = quote! {
19 #input
20
21 #[allow(non_snake_case)]
22 pub fn #route_fn_name() -> ::spring_axum::Router {
23 #router_stmt
24 }
25 };
26 expanded.into()
27}
28
29struct RouteArgs {
31 path: LitStr,
32 layer: Option<Expr>,
33}
34
35impl Parse for RouteArgs {
36 fn parse(input: ParseStream) -> syn::Result<Self> {
37 let path: LitStr = input.parse()?;
38 let mut layer: Option<Expr> = None;
39 if input.peek(Token![,]) {
40 let _comma: Token![,] = input.parse()?;
41 let key: Ident = input.parse()?;
42 if key == "layer" {
43 let _eq: Token![=] = input.parse()?;
44 let expr: Expr = input.parse()?;
45 layer = Some(expr);
46 } else {
47 return Err(syn::Error::new(key.span(), "unsupported key, expected `layer`"));
48 }
49 }
50 Ok(Self { path, layer })
51 }
52}
53
54#[proc_macro_attribute]
55pub fn route_get(args: TokenStream, input: TokenStream) -> TokenStream {
56 let args = parse_macro_input!(args as RouteArgs);
57 let input = parse_macro_input!(input as ItemFn);
58 route_macro_impl("get", args.path, input, args.layer)
59}
60
61#[proc_macro_attribute]
62pub fn route_post(args: TokenStream, input: TokenStream) -> TokenStream {
63 let args = parse_macro_input!(args as RouteArgs);
64 let input = parse_macro_input!(input as ItemFn);
65 route_macro_impl("post", args.path, input, args.layer)
66}
67
68#[proc_macro_attribute]
69pub fn route_put(args: TokenStream, input: TokenStream) -> TokenStream {
70 let args = parse_macro_input!(args as RouteArgs);
71 let input = parse_macro_input!(input as ItemFn);
72 route_macro_impl("put", args.path, input, args.layer)
73}
74
75#[proc_macro_attribute]
76pub fn route_delete(args: TokenStream, input: TokenStream) -> TokenStream {
77 let args = parse_macro_input!(args as RouteArgs);
78 let input = parse_macro_input!(input as ItemFn);
79 route_macro_impl("delete", args.path, input, args.layer)
80}
81
82fn replace_type_path_to_validated(tp: TypePath, target: &str, replacement: &str) -> Option<Type> {
83 let last = tp.path.segments.last()?.ident.to_string();
85 if last == target {
86 if let syn::PathArguments::AngleBracketed(args) = &tp.path.segments.last()?.arguments {
88 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
89 let new_ty: Type = syn::parse_str(&format!("::spring_axum::{}<{}>", replacement, quote!(#inner_ty))).ok()?;
90 return Some(new_ty);
91 }
92 }
93 }
94 None
95}
96
97fn validate_transform(input: ItemFn, kind: &str) -> TokenStream {
98 let mut func = input.clone();
99 for arg in func.sig.inputs.iter_mut() {
100 if let FnArg::Typed(pt) = arg {
101 if let Type::Path(tp) = &*pt.ty {
102 let (target, repl) = match kind {
103 "json" => ("Json", "ValidatedJson"),
104 "query" => ("Query", "ValidatedQuery"),
105 "form" => ("Form", "ValidatedForm"),
106 "json_stream" => ("Json", "ValidatedJsonStream"),
107 _ => unreachable!(),
108 };
109 if let Some(new_ty) = replace_type_path_to_validated(tp.clone(), target, repl) {
110 pt.ty = Box::new(new_ty);
111 }
112 }
113 }
114 }
115 quote! { #func }.into()
116}
117
118#[proc_macro_attribute]
119pub fn validate_json(_args: TokenStream, input: TokenStream) -> TokenStream {
120 let input = parse_macro_input!(input as ItemFn);
121 validate_transform(input, "json")
122}
123
124#[proc_macro_attribute]
125pub fn validate_query(_args: TokenStream, input: TokenStream) -> TokenStream {
126 let input = parse_macro_input!(input as ItemFn);
127 validate_transform(input, "query")
128}
129
130#[proc_macro_attribute]
131pub fn validate_form(_args: TokenStream, input: TokenStream) -> TokenStream {
132 let input = parse_macro_input!(input as ItemFn);
133 validate_transform(input, "form")
134}
135
136#[proc_macro_attribute]
137pub fn validate_json_stream(_args: TokenStream, input: TokenStream) -> TokenStream {
138 let input = parse_macro_input!(input as ItemFn);
139 validate_transform(input, "json_stream")
140}
141
142struct IdentList(Punctuated<Ident, Token![,]>);
143impl Parse for IdentList {
144 fn parse(input: ParseStream) -> syn::Result<Self> {
145 Ok(Self(Punctuated::parse_terminated(input)?))
146 }
147}
148
149#[proc_macro_attribute]
150pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
151 let input_struct = parse_macro_input!(input as ItemStruct);
152 let idents = parse_macro_input!(args as IdentList).0;
153
154 let routes: Vec<Ident> = idents
155 .iter()
156 .map(|i| Ident::new(&format!("__spring_axum_route_{}", i), i.span()))
157 .collect();
158
159 let name = input_struct.ident.clone();
160 let mut merge_stmts = quote! { let mut router = ::spring_axum::Router::new(); };
161 for r in routes {
162 merge_stmts = quote! {
163 #merge_stmts
164 router = router.merge(#r());
165 };
166 }
167
168 let expanded = quote! {
169 #input_struct
170
171 impl ::spring_axum::Controller for #name {
172 fn routes(&self) -> ::spring_axum::Router {
173 #merge_stmts
174 router
175 }
176 }
177
178 inventory::submit!(::spring_axum::ControllerRouterRegistration {
180 init: || -> ::spring_axum::Router {
181 #merge_stmts
182 router
183 },
184 });
185 };
186 expanded.into()
187}
188
189#[proc_macro_attribute]
190pub fn component(_args: TokenStream, input: TokenStream) -> TokenStream {
191 let input_item = parse_macro_input!(input as ItemStruct);
192 let name = input_item.ident.clone();
193
194 let expanded = quote! {
196 #input_item
197
198 inventory::submit!(::spring_axum::ComponentRegistration {
199 init: |_: &::spring_axum::ApplicationContext| -> (::std::any::TypeId, Box<dyn ::std::any::Any + Send + Sync>) {
200 let value: #name = ::std::default::Default::default();
201 let arc: ::std::sync::Arc<#name> = ::std::sync::Arc::new(value);
202 (::std::any::TypeId::of::<#name>(), Box::new(arc))
203 },
204 });
205 };
206 expanded.into()
207}
208
209#[proc_macro_attribute]
210pub fn interceptor(_args: TokenStream, input: TokenStream) -> TokenStream {
211 let input_item = parse_macro_input!(input as ItemStruct);
212 let name = input_item.ident.clone();
213
214 let expanded = quote! {
215 #input_item
216
217 inventory::submit!(::spring_axum::InterceptorRegistration {
219 apply: |router: ::spring_axum::Router| -> ::spring_axum::Router {
220 router.layer(::spring_axum::InterceptorLayer::new(#name::default()))
221 },
222 });
223 };
224 expanded.into()
225}
226
227#[proc_macro_attribute]
230pub fn transactional(_args: TokenStream, input: TokenStream) -> TokenStream {
231 let mut func = parse_macro_input!(input as ItemFn);
232 let body = func.block.clone();
233 func.block = Box::new(syn::parse_quote!({
234 ::spring_axum::transaction(|| async move { #body }).await
235 }));
236 quote! { #func }.into()
237}
238
239#[proc_macro_attribute]
241pub fn non_tx(_args: TokenStream, input: TokenStream) -> TokenStream {
242 input
244}
245
246#[proc_macro_attribute]
249pub fn tx_service(_args: TokenStream, input: TokenStream) -> TokenStream {
250 let mut item_impl = parse_macro_input!(input as ItemImpl);
251
252 for impl_item in item_impl.items.iter_mut() {
253 if let ImplItem::Fn(method) = impl_item {
254 let has_non_tx = method.attrs.iter().any(|a| a.path().is_ident("non_tx"));
256 if has_non_tx {
257 method.attrs.retain(|a| !a.path().is_ident("non_tx"));
258 continue;
259 }
260
261 let body = method.block.clone();
263 method.block = syn::parse_quote!({
264 ::spring_axum::transaction(|| async move { #body }).await
265 });
266
267 if method.sig.asyncness.is_none() {
269 method.sig.asyncness = Some(syn::token::Async { span: method.sig.fn_token.span });
270 }
271 }
272 }
273
274 quote! { #item_impl }.into()
275}
276
277struct CacheArgs { ttl_secs: Option<u64> }
279impl Parse for CacheArgs {
280 fn parse(input: ParseStream) -> syn::Result<Self> {
281 if input.is_empty() { return Ok(Self { ttl_secs: None }); }
282 let key: Ident = input.parse()?;
283 if key != "ttl" { return Err(syn::Error::new(key.span(), "expected ttl = <secs>")); }
284 let _eq: Token![=] = input.parse()?;
285 let lit: syn::LitInt = input.parse()?;
286 let secs = lit.base10_parse::<u64>()?;
287 Ok(Self { ttl_secs: Some(secs) })
288 }
289}
290
291fn extract_app_result_inner_ty(fn_item: &ItemFn) -> Option<Type> {
292 if let syn::ReturnType::Type(_, ty_box) = &fn_item.sig.output {
293 if let Type::Path(tp) = &**ty_box {
294 if let Some(seg) = tp.path.segments.last() {
295 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
297 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
298 return Some(inner.clone());
299 }
300 }
301 }
302 }
303 }
304 None
305}
306
307#[proc_macro_attribute]
308pub fn cacheable(args: TokenStream, input: TokenStream) -> TokenStream {
309 let args = parse_macro_input!(args as CacheArgs);
310 let mut func = parse_macro_input!(input as ItemFn);
311 let fn_name = func.sig.ident.to_string();
312 let inner_ty = extract_app_result_inner_ty(&func).expect("cacheable requires return type AppResult<T>");
313 let body = func.block.clone();
314 let ttl_expr = if let Some(secs) = args.ttl_secs { quote! { Some(::std::time::Duration::from_secs(#secs)) } } else { quote! { None } };
315 let mut fields = Vec::<proc_macro2::TokenStream>::new();
317 for arg in func.sig.inputs.iter() {
318 if let FnArg::Typed(pt) = arg {
319 if let syn::Pat::Ident(pident) = &*pt.pat {
320 let ident = &pident.ident;
321 fields.push(quote! { stringify!(#ident) : &#ident });
322 }
323 }
324 }
325 let expanded_body = quote!({
326 let __args_json = ::serde_json::json!({ #(#fields),* });
327 let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
328 if let Some(__cached) = ::spring_axum::cache_instance().get_typed::<#inner_ty>(&__key) {
329 return Ok(__cached);
330 }
331 let __res: ::spring_axum::AppResult<#inner_ty> = (async move { #body }).await;
332 match __res {
333 Ok(__val) => {
334 ::spring_axum::cache_instance().put_typed(__key, __val.clone(), #ttl_expr);
335 Ok(__val)
336 }
337 Err(e) => Err(e),
338 }
339 });
340 func.block = Box::new(syn::parse_quote! { #expanded_body });
341 quote! { #func }.into()
342}
343
344#[proc_macro_attribute]
345pub fn cache_evict(_args: TokenStream, input: TokenStream) -> TokenStream {
346 let mut func = parse_macro_input!(input as ItemFn);
347 let fn_name = func.sig.ident.to_string();
348 let body = func.block.clone();
349 let mut fields = Vec::<proc_macro2::TokenStream>::new();
350 for arg in func.sig.inputs.iter() {
351 if let FnArg::Typed(pt) = arg {
352 if let syn::Pat::Ident(pident) = &*pt.pat {
353 let ident = &pident.ident;
354 fields.push(quote! { stringify!(#ident) : &#ident });
355 }
356 }
357 }
358 let expanded_body = quote!({
359 let __args_json = ::serde_json::json!({ #(#fields),* });
360 let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
361 ::spring_axum::cache_instance().evict(&__key);
362 (async move { #body }).await
363 });
364 func.block = Box::new(syn::parse_quote! { #expanded_body });
365 quote! { #func }.into()
366}
367
368#[proc_macro_attribute]
369pub fn cache_put(args: TokenStream, input: TokenStream) -> TokenStream {
370 let args = parse_macro_input!(args as CacheArgs);
371 let mut func = parse_macro_input!(input as ItemFn);
372 let fn_name = func.sig.ident.to_string();
373 let inner_ty = extract_app_result_inner_ty(&func).expect("cache_put requires return type AppResult<T>");
374 let body = func.block.clone();
375 let ttl_expr = if let Some(secs) = args.ttl_secs { quote! { Some(::std::time::Duration::from_secs(#secs)) } } else { quote! { None } };
376 let mut fields = Vec::<proc_macro2::TokenStream>::new();
377 for arg in func.sig.inputs.iter() {
378 if let FnArg::Typed(pt) = arg {
379 if let syn::Pat::Ident(pident) = &*pt.pat {
380 let ident = &pident.ident;
381 fields.push(quote! { stringify!(#ident) : &#ident });
382 }
383 }
384 }
385 let expanded_body = quote!({
386 let __args_json = ::serde_json::json!({ #(#fields),* });
387 let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
388 let __res: ::spring_axum::AppResult<#inner_ty> = (async move { #body }).await;
389 match __res {
390 Ok(__val) => {
391 ::spring_axum::cache_instance().put_typed(__key, __val.clone(), #ttl_expr);
392 Ok(__val)
393 }
394 Err(e) => Err(e),
395 }
396 });
397 func.block = Box::new(syn::parse_quote! { #expanded_body });
398 quote! { #func }.into()
399}
400
401struct EventTypePath(TypePath);
403impl Parse for EventTypePath {
404 fn parse(input: ParseStream) -> syn::Result<Self> { Ok(Self(input.parse()?)) }
405}
406
407#[proc_macro_attribute]
408pub fn event_listener(args: TokenStream, input: TokenStream) -> TokenStream {
409 let event_ty = parse_macro_input!(args as EventTypePath).0;
410 let func = parse_macro_input!(input as ItemFn);
411 let name = func.sig.ident.clone();
412 let expanded = quote! {
413 #func
414 inventory::submit!(::spring_axum::EventListenerRegistration {
415 type_id: ::std::any::TypeId::of::<#event_ty>(),
416 handle: |ev: &dyn ::std::any::Any, ctx: &::spring_axum::ApplicationContext| {
417 if let Some(e) = ev.downcast_ref::<#event_ty>() {
418 #name(e, ctx);
419 }
420 },
421 });
422 };
423 expanded.into()
424}
425
426#[proc_macro_attribute]
431pub fn sql(_args: TokenStream, input: TokenStream) -> TokenStream {
432 input
433}
434
435#[proc_macro_attribute]
439pub fn mapper(args: TokenStream, input: TokenStream) -> TokenStream {
440 #[derive(Default)]
442 struct MapperArgs { namespace: Option<String> }
443 impl Parse for MapperArgs {
444 fn parse(input: ParseStream) -> syn::Result<Self> {
445 if input.is_empty() { return Ok(MapperArgs::default()); }
446 let key: Ident = input.parse()?;
447 if key != "namespace" { return Err(syn::Error::new(key.span(), "expected `namespace`")); }
448 let _eq: Token![=] = input.parse()?;
449 let lit: LitStr = input.parse()?;
450 Ok(MapperArgs { namespace: Some(lit.value()) })
451 }
452 }
453 let parsed_args = parse_macro_input!(args as MapperArgs);
454 let mut item_impl = parse_macro_input!(input as syn::ItemImpl);
455
456 let ns = if let Some(ns) = parsed_args.namespace {
458 ns
459 } else {
460 match &*item_impl.self_ty {
462 Type::Path(tp) => tp.path.segments.last().map(|s| s.ident.to_string()).unwrap_or_default(),
463 _ => String::new(),
464 }
465 };
466 let ns_lit = LitStr::new(&ns, proc_macro2::Span::call_site());
467
468 let mut new_items: Vec<syn::ImplItem> = Vec::new();
470 for it in item_impl.items.into_iter() {
471 if let syn::ImplItem::Fn(mut m) = it {
472 let has_sql = m.attrs.iter().any(|a| a.path().segments.last().map(|s| s.ident == "sql").unwrap_or(false));
473 if has_sql {
474 m.attrs.retain(|a| !a.path().segments.last().map(|s| s.ident == "sql").unwrap_or(false));
476
477 let mut param_idents: Vec<Ident> = Vec::new();
479 for arg in m.sig.inputs.iter() {
480 if let FnArg::Typed(pt) = arg {
481 if let syn::Pat::Ident(pi) = &*pt.pat {
482 param_idents.push(pi.ident.clone());
483 }
484 }
485 }
486
487 let json_pairs: Vec<proc_macro2::TokenStream> = param_idents
489 .iter()
490 .map(|id| {
491 let key = LitStr::new(&id.to_string(), id.span());
492 quote! { #key : #id }
493 })
494 .collect();
495
496 let method_name = m.sig.ident.clone();
497 let stmt_id = quote! { concat!(#ns_lit, ".", stringify!(#method_name)) };
498
499 m.block = syn::parse_quote!({
501 let params = ::serde_json::json!({ #(#json_pairs),* });
502 let exec = ::spring_axum_mybatis::NoopExecutor::default();
503 ::spring_axum::mybatis_exec!(exec, #stmt_id, params);
504 Ok(())
505 });
506
507 new_items.push(syn::ImplItem::Fn(m));
508 } else {
509 new_items.push(syn::ImplItem::Fn(m));
510 }
511 } else {
512 new_items.push(it);
513 }
514 }
515
516 item_impl.items = new_items;
517 quote! { #item_impl }.into()
518}