// rust_dicore_macros/lib.rs

1//! rust-di procedural macros.
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input,
7    punctuated::Punctuated,
8    DeriveInput, Field, Fields, ItemMod, Token,
9};
10
11#[proc_macro_derive(Inject, attributes(inject))]
12pub fn inject_derive(input: TokenStream) -> TokenStream {
13    expand_inject(parse_macro_input!(input as DeriveInput))
14        .unwrap_or_else(|e| e.to_compile_error())
15        .into()
16}
17
18fn expand_inject(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
19    let name = &input.ident;
20    let named = match &input.data {
21        syn::Data::Struct(s) => match &s.fields {
22            Fields::Named(n) => n,
23            _ => return Err(syn::Error::new_spanned(name, "named fields")),
24        },
25        _ => return Err(syn::Error::new_spanned(name, "struct only")),
26    };
27    let fn_name = format_ident!("__rdi_construct_{}", name);
28    let mut inits = Vec::new();
29    for field in named.named.iter() {
30        let attrs = parse_ia(field);
31        let fnm = field.ident.as_ref().unwrap();
32        let (inner, _) = saw(&field.ty);
33        let init = if attrs.skip {
34            quote! {#fnm:Default::default()}
35        } else if attrs.provider {
36            quote! {#fnm:resolver.clone()}
37        } else if attrs.optional {
38            if let Some(k) = &attrs.key {
39                quote! {#fnm:resolver.get_keyed_any(::std::any::type_name::<#inner>(),#k).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d))}
40            } else {
41                quote! {#fnm:resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d))}
42            }
43        } else if let Some(k) = &attrs.key {
44            quote! {#fnm:resolver.get_keyed_any(::std::any::type_name::<#inner>(),#k).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d)).unwrap_or_else(||::std::panic!("keyed not found"))}
45        } else {
46            quote! {#fnm:resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d)).unwrap_or_else(||::std::panic!("svc not registered"))}
47        };
48        inits.push(init);
49    }
50    Ok(
51        quote! {#[doc(hidden)]pub fn #fn_name(resolver:&dyn rust_dicore::IServiceResolver)->::std::sync::Arc<#name>{::std::sync::Arc::new(#name{#(#inits),*})}},
52    )
53}
54
55fn saw(ty: &syn::Type) -> (proc_macro2::TokenStream, bool) {
56    if let syn::Type::Path(p) = ty {
57        let l = p.path.segments.last().unwrap();
58        if l.ident == "Arc" {
59            if let syn::PathArguments::AngleBracketed(a) = &l.arguments {
60                if let Some(syn::GenericArgument::Type(i)) = a.args.first() {
61                    return (quote! {#i}, true);
62                }
63            }
64        }
65        if l.ident == "Option" {
66            if let syn::PathArguments::AngleBracketed(a) = &l.arguments {
67                if let Some(syn::GenericArgument::Type(syn::Type::Path(ip))) = a.args.first() {
68                    if ip
69                        .path
70                        .segments
71                        .last()
72                        .map(|s| s.ident == "Arc")
73                        .unwrap_or(false)
74                    {
75                        if let syn::PathArguments::AngleBracketed(ia) =
76                            &ip.path.segments.last().unwrap().arguments
77                        {
78                            if let Some(syn::GenericArgument::Type(t)) = ia.args.first() {
79                                return (quote! {#t}, true);
80                            }
81                        }
82                    }
83                }
84            }
85        }
86    }
87    (quote! {#ty}, false)
88}
89
/// Options collected from a field's `#[inject(...)]` attributes (see `parse_ia`).
#[derive(Default)]
struct IA {
    /// `skip`: the field is initialised with `Default::default()` instead of being resolved.
    skip: bool,
    /// `optional`: the looked-up `Option` is stored as-is instead of being unwrapped.
    optional: bool,
    /// `provider`: the field is initialised with `resolver.clone()`.
    provider: bool,
    /// `key = "..."`: the service is looked up with `get_keyed_any` under this key.
    key: Option<String>,
}
97fn parse_ia(f: &Field) -> IA {
98    let mut a = IA::default();
99    for attr in &f.attrs {
100        if !attr.path().is_ident("inject") {
101            continue;
102        }
103        let Ok(l) = attr.meta.require_list() else {
104            continue;
105        };
106        l.parse_nested_meta(|m| {
107            if m.path.is_ident("skip") {
108                a.skip = true;
109            } else if m.path.is_ident("optional") {
110                a.optional = true;
111            } else if m.path.is_ident("provider") {
112                a.provider = true;
113            } else if m.path.is_ident("key") {
114                a.key = Some(m.value()?.parse::<syn::LitStr>()?.value());
115            }
116            Ok(())
117        })
118        .ok();
119    }
120    a
121}
122
123#[proc_macro]
124pub fn inject(_: TokenStream) -> TokenStream {
125    quote! {}.into()
126}
127
128#[proc_macro_attribute]
129pub fn module(_: TokenStream, item: TokenStream) -> TokenStream {
130    expand_md(parse_macro_input!(item as ItemMod))
131        .unwrap_or_else(|e| e.to_compile_error())
132        .into()
133}
134
135fn expand_md(mut m: ItemMod) -> syn::Result<proc_macro2::TokenStream> {
136    let mn = m.ident.clone();
137    let fn_n = format_ident!("__rdi_build_provider_{}", mn);
138    let is = match &m.content {
139        Some((_, i)) => i.clone(),
140        None => return Err(syn::Error::new_spanned(m, "body required")),
141    };
142    let mut rs = Vec::new();
143    let mut cl = Vec::new();
144    for i in &is {
145        match i {
146            syn::Item::Macro(mc) => {
147                let ps = mc
148                    .mac
149                    .path
150                    .segments
151                    .iter()
152                    .map(|s| s.ident.to_string())
153                    .collect::<Vec<_>>()
154                    .join("::");
155                if ps == "inject" || ps == "rust_dicore::inject" {
156                    if let Ok(r) = syn::parse2::<ID>(mc.mac.tokens.clone()) {
157                        rs.push(r);
158                    }
159                } else {
160                    cl.push(i.clone());
161                }
162            }
163            _ => cl.push(i.clone()),
164        }
165    }
166    let mut ch = Vec::new();
167    for r in &rs {
168        match &r.kind {
169            IK::N { lt, ty, imp } => {
170                let mt = lmt(*lt);
171                if let Some(imp_ty) = imp {
172                    ch.push(quote!{ .#mt::<#ty>(|_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#imp_ty as ::std::default::Default>::default())) });
173                } else {
174                    ch.push(quote!{ .#mt(|_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#ty as ::std::default::Default>::default())) });
175                }
176            }
177            IK::K { key, lt, ty } => {
178                let mt = kmt(*lt);
179                ch.push(quote!{ .#mt::<#ty>(#key,|_:&dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#ty as ::std::default::Default>::default())) });
180            }
181            IK::F { lt, f } => {
182                let mt = lmt(*lt);
183                ch.push(
184                    quote! { .#mt(move |_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(#f)) },
185                );
186            }
187        }
188    }
189    vd(&rs)?;
190    let bi: syn::Item = syn::parse2(quote! {
191        #[doc(hidden)]
192        pub fn #fn_n() -> ::std::result::Result<::std::sync::Arc<rust_dicore::ServiceProvider>, rust_dicore::RdiError> {
193            Ok(::std::sync::Arc::new(rust_dicore::ServiceCollection::new() #(#ch)* .build()?))
194        }
195    })
196    .unwrap();
197    cl.push(bi);
198    m.content = Some((syn::token::Brace::default(), cl));
199    Ok(quote! {#m})
200}
201fn lmt(lt: LT) -> proc_macro2::TokenStream {
202    match lt {
203        LT::S => quote! {singleton},
204        LT::Sc => quote! {scoped},
205        LT::T => quote! {transient},
206    }
207}
208fn kmt(lt: LT) -> proc_macro2::TokenStream {
209    match lt {
210        LT::S => quote! {keyed},
211        LT::Sc => quote! {keyed_scoped},
212        LT::T => quote! {keyed_transient},
213    }
214}
215fn vd(rs: &[ID]) -> syn::Result<()> {
216    let mut sn: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
217    for r in rs {
218        if let IK::K { key, .. } = &r.kind {
219            let e = sn.entry(key.clone()).or_default();
220            *e += 1;
221            if *e > 1 {
222                return Err(syn::Error::new(
223                    proc_macro2::Span::call_site(),
224                    format!("rdi-E004: duplicate key `{key}`"),
225                ));
226            }
227        }
228    }
229    Ok(())
230}
231
/// Service lifetime keyword accepted by the macros.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LT {
    /// Written `singleton`.
    S,
    /// Written `scoped`.
    Sc,
    /// Written `transient`.
    T,
}
238impl Parse for LT {
239    fn parse(i: ParseStream) -> syn::Result<Self> {
240        match i.parse::<syn::Ident>()?.to_string().as_str() {
241            "singleton" => Ok(LT::S),
242            "scoped" => Ok(LT::Sc),
243            "transient" => Ok(LT::T),
244            o => Err(syn::Error::new(i.span(), format!("unknown lifetime: {o}"))),
245        }
246    }
247}
/// The kind of registration declared by one `inject!(...)` item inside a module.
#[derive(Debug)]
enum IK {
    /// Unkeyed registration, `<lifetime>: Type` optionally followed by `=> Impl`.
    N {
        /// Lifetime selecting the builder method (see `lmt`).
        lt: LT,
        /// Type written after the `:`.
        ty: syn::Type,
        /// Optional type written after `=>`; when present it is the one default-constructed.
        imp: Option<syn::Type>,
    },
    /// Keyed registration, introduced by the `keyed` keyword and a string key.
    K {
        /// The string key.
        key: String,
        /// Lifetime selecting the keyed builder method (see `kmt`).
        lt: LT,
        /// Type that is registered and default-constructed.
        ty: syn::Type,
    },
    /// Factory registration, introduced by the `factory` keyword.
    F {
        /// Lifetime selecting the builder method (see `lmt`).
        lt: LT,
        /// Expression written after `=>`; the generated closure wraps it in `Arc::new`.
        f: syn::Expr,
    },
}
/// One parsed `inject!(...)` declaration from a module body.
#[derive(Debug)]
struct ID {
    /// What is being registered and how.
    kind: IK,
}
269impl Parse for ID {
270    fn parse(i: ParseStream) -> syn::Result<Self> {
271        let mk = if i.peek(syn::Ident) && i.fork().parse::<syn::Ident>()? == "keyed" {
272            let _: syn::Ident = i.parse()?;
273            let k: syn::LitStr = i.parse()?;
274            let _: Token![:] = i.parse()?;
275            let lt: LT = i.parse()?;
276            Some((k.value(), lt))
277        } else {
278            None
279        };
280        if i.peek(syn::Ident) && i.fork().parse::<syn::Ident>()? == "factory" {
281            let _: syn::Ident = i.parse()?;
282            let lt: LT = i.parse()?;
283            let _: Token![:] = i.parse()?;
284            let _: syn::Type = i.parse()?;
285            let _: Token![=>] = i.parse()?;
286            let f: syn::Expr = i.parse()?;
287            return Ok(ID {
288                kind: IK::F { lt, f },
289            });
290        }
291        if let Some((k, l)) = mk {
292            let _: Token![:] = i.parse()?;
293            let ty: syn::Type = i.parse()?;
294            let _ = i.parse::<Token![=>]>();
295            if !i.is_empty() && !i.peek(Token![|]) {
296                let _: syn::Type = i.parse()?;
297            }
298            return Ok(ID {
299                kind: IK::K { key: k, lt: l, ty },
300            });
301        }
302        let lt: LT = i.parse()?;
303        let _: Token![:] = i.parse()?;
304        let ty: syn::Type = i.parse()?;
305        let _ = i.parse::<Token![=>]>();
306        let imp: Option<syn::Type> = if !i.is_empty() && !i.peek(Token![|]) {
307            Some(i.parse::<syn::Type>()?)
308        } else {
309            None
310        };
311        Ok(ID {
312            kind: IK::N { lt, ty, imp },
313        })
314    }
315}
316
317// ── `#[rust_dicore::inject(...)]` attribute macro ──
318
/// Parsed arguments of `#[rust_dicore::inject(...)]`
enum InjectArgs {
    /// `<lifetime>` only: the struct is registered under its own type.
    Plain {
        /// Requested service lifetime.
        lifetime: LT,
    },
    /// `<lifetime>, as = Trait`: the struct is registered under a single other type.
    AsTrait {
        /// Requested service lifetime.
        lifetime: LT,
        /// Type written after `as =`.
        trait_ty: syn::Type,
    },
    /// `<lifetime>, as = [A, B, ...]`: the struct is registered under each listed type.
    AsTraits {
        /// Requested service lifetime.
        lifetime: LT,
        /// Types listed inside the brackets, in source order.
        trait_tys: Vec<syn::Type>,
    },
}
333
334impl Parse for InjectArgs {
335    fn parse(input: ParseStream) -> syn::Result<Self> {
336        let lt: LT = input.parse()?;
337
338        if input.peek(Token![,]) {
339            let _: Token![,] = input.parse()?;
340            if input.peek(Token![as]) {
341                let _: Token![as] = input.parse()?;
342                let _: Token![=] = input.parse()?;
343
344                if input.peek(syn::token::Bracket) {
345                    let content;
346                    let _ = syn::bracketed!(content in input);
347                    let tys: Punctuated<syn::Type, Token![,]> =
348                        content.parse_terminated(syn::Type::parse, Token![,])?;
349                    return Ok(InjectArgs::AsTraits {
350                        lifetime: lt,
351                        trait_tys: tys.into_iter().collect(),
352                    });
353                } else {
354                    let ty: syn::Type = input.parse()?;
355                    return Ok(InjectArgs::AsTrait {
356                        lifetime: lt,
357                        trait_ty: ty,
358                    });
359                }
360            }
361        }
362
363        Ok(InjectArgs::Plain { lifetime: lt })
364    }
365}
366
367#[proc_macro_attribute]
368pub fn inject_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
369    expand_inject_attr(
370        parse_macro_input!(attr as InjectArgs),
371        parse_macro_input!(item as syn::Item),
372    )
373    .unwrap_or_else(|e| e.to_compile_error())
374    .into()
375}
376
377fn expand_inject_attr(args: InjectArgs, item: syn::Item) -> syn::Result<proc_macro2::TokenStream> {
378    let struct_item = match &item {
379        syn::Item::Struct(s) => s,
380        _ => return Err(syn::Error::new_spanned(&item, "only structs are supported")),
381    };
382
383    let name = &struct_item.ident;
384    let fn_name = format_ident!("__rdi_construct_{}", name);
385    let factory_name = format_ident!("__rdi_factory_{}", name);
386
387    let constructor_body = match &struct_item.fields {
388        syn::Fields::Named(n) => {
389            let mut inits = Vec::new();
390            for field in n.named.iter() {
391                let attrs = parse_ia(field);
392                let fnm = field.ident.as_ref().unwrap();
393                let (inner, _) = saw(&field.ty);
394
395                let init = if attrs.skip {
396                    quote! { #fnm: ::std::default::Default::default() }
397                } else if attrs.provider {
398                    quote! { #fnm: resolver.clone() }
399                } else if attrs.optional {
400                    if let Some(k) = &attrs.key {
401                        quote! { #fnm: resolver.get_keyed_any(::std::any::type_name::<#inner>(), #k).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)) }
402                    } else {
403                        quote! { #fnm: resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)) }
404                    }
405                } else if let Some(k) = &attrs.key {
406                    quote! { #fnm: resolver.get_keyed_any(::std::any::type_name::<#inner>(), #k).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)).unwrap_or_else(|| ::std::panic!("keyed not found")) }
407                } else {
408                    quote! { #fnm: resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)).unwrap_or_else(|| ::std::panic!("svc not registered")) }
409                };
410                inits.push(init);
411            }
412            quote! { #name { #(#inits),* } }
413        }
414        syn::Fields::Unit => quote! { #name },
415        _ => {
416            return Err(syn::Error::new_spanned(
417                name,
418                "named struct or unit struct required",
419            ))
420        }
421    };
422
423    let constructor = quote! {
424        #[doc(hidden)]
425        #[allow(non_snake_case)]
426        pub fn #fn_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<#name> {
427            ::std::sync::Arc::new(#constructor_body)
428        }
429    };
430
431    let lt_ident = match args {
432        InjectArgs::Plain { lifetime: LT::S }
433        | InjectArgs::AsTrait {
434            lifetime: LT::S, ..
435        }
436        | InjectArgs::AsTraits {
437            lifetime: LT::S, ..
438        } => {
439            quote! { rust_dicore::ServiceLifetime::Singleton }
440        }
441        InjectArgs::Plain { lifetime: LT::Sc }
442        | InjectArgs::AsTrait {
443            lifetime: LT::Sc, ..
444        }
445        | InjectArgs::AsTraits {
446            lifetime: LT::Sc, ..
447        } => {
448            quote! { rust_dicore::ServiceLifetime::Scoped }
449        }
450        InjectArgs::Plain { lifetime: LT::T }
451        | InjectArgs::AsTrait {
452            lifetime: LT::T, ..
453        }
454        | InjectArgs::AsTraits {
455            lifetime: LT::T, ..
456        } => {
457            quote! { rust_dicore::ServiceLifetime::Transient }
458        }
459    };
460
461    // Generate factory function(s) — for AsTrait/AsTraits, upcast to dyn Trait first
462    let factory_fns: Vec<proc_macro2::TokenStream> = match &args {
463        InjectArgs::Plain { .. } => {
464            vec![quote! {
465                #[doc(hidden)]
466                #[allow(non_snake_case)]
467                fn #factory_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
468                    let v: ::std::sync::Arc<#name> = #fn_name(resolver);
469                    ::std::sync::Arc::new(v)
470                        as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
471                }
472            }]
473        }
474        InjectArgs::AsTrait { trait_ty, .. } => {
475            vec![quote! {
476                #[doc(hidden)]
477                #[allow(non_snake_case)]
478                fn #factory_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
479                    let v: ::std::sync::Arc<#name> = #fn_name(resolver);
480                    let v2: ::std::sync::Arc<#trait_ty> = v;
481                    ::std::sync::Arc::new(v2)
482                        as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
483                }
484            }]
485        }
486        InjectArgs::AsTraits { trait_tys, .. } => {
487            trait_tys.iter().enumerate().map(|(i, trait_ty)| {
488                let fn_name = if i == 0 {
489                    factory_name.clone()
490                } else {
491                    format_ident!("__rdi_factory_{}_{}", name, i)
492                };
493                quote! {
494                    #[doc(hidden)]
495                    #[allow(non_snake_case)]
496                    fn #fn_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
497                        let v: ::std::sync::Arc<#name> = #fn_name(resolver);
498                        let v2: ::std::sync::Arc<#trait_ty> = v;
499                        ::std::sync::Arc::new(v2)
500                            as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
501                    }
502                }
503            }).collect()
504        }
505    };
506
507    let type_name_fn_name = format_ident!("__rdi_type_name_{}", name);
508
509    let (type_name_helper, trait_tys_for_subs): (proc_macro2::TokenStream, Vec<syn::Type>) =
510        match &args {
511            InjectArgs::Plain { .. } => {
512                let helper = quote! {
513                    #[doc(hidden)]
514                    #[allow(non_snake_case)]
515                    fn #type_name_fn_name() -> &'static str {
516                        ::std::any::type_name::<#name>()
517                    }
518                };
519                (helper, vec![])
520            }
521            InjectArgs::AsTrait { trait_ty, .. } => {
522                let helper = quote! {
523                    #[doc(hidden)]
524                    #[allow(non_snake_case)]
525                    fn #type_name_fn_name() -> &'static str {
526                        ::std::any::type_name::<#trait_ty>()
527                    }
528                };
529                (helper, vec![trait_ty.clone()])
530            }
531            InjectArgs::AsTraits { trait_tys, .. } => {
532                let first = &trait_tys[0];
533                let helper = quote! {
534                    #[doc(hidden)]
535                    #[allow(non_snake_case)]
536                    fn #type_name_fn_name() -> &'static str {
537                        ::std::any::type_name::<#first>()
538                    }
539                };
540                // Generate extra helpers for remaining traits
541                let extra: Vec<proc_macro2::TokenStream> = trait_tys[1..]
542                    .iter()
543                    .enumerate()
544                    .map(|(i, ty)| {
545                        let hn = format_ident!("__rdi_type_name_{}_{}", name, i + 1);
546                        quote! {
547                            #[doc(hidden)]
548                            #[allow(non_snake_case)]
549                            fn #hn() -> &'static str {
550                                ::std::any::type_name::<#ty>()
551                            }
552                        }
553                    })
554                    .collect();
555                let all_helpers = quote! {
556                    #helper
557                    #(#extra)*
558                };
559                (all_helpers, trait_tys.clone())
560            }
561        };
562
563    let submissions = match &args {
564        InjectArgs::Plain { .. } => {
565            quote! {
566                rust_dicore::inventory::submit! {
567                    rust_dicore::ServiceRegistration {
568                        lifetime: #lt_ident,
569                        type_id: ::std::any::TypeId::of::<#name>(),
570                        type_name_fn: #type_name_fn_name,
571                        factory: #factory_name,
572                    }
573                }
574            }
575        }
576        InjectArgs::AsTrait { trait_ty, .. } => {
577            quote! {
578                rust_dicore::inventory::submit! {
579                    rust_dicore::ServiceRegistration {
580                        lifetime: #lt_ident,
581                        type_id: ::std::any::TypeId::of::<#trait_ty>(),
582                        type_name_fn: #type_name_fn_name,
583                        factory: #factory_name,
584                    }
585                }
586            }
587        }
588        InjectArgs::AsTraits { .. } => {
589            let mut subs = Vec::new();
590            for (i, trait_ty) in trait_tys_for_subs.iter().enumerate() {
591                let helper = if i == 0 {
592                    type_name_fn_name.clone()
593                } else {
594                    format_ident!("__rdi_type_name_{}_{}", name, i)
595                };
596                subs.push(quote! {
597                    rust_dicore::inventory::submit! {
598                        rust_dicore::ServiceRegistration {
599                            lifetime: #lt_ident,
600                            type_id: ::std::any::TypeId::of::<#trait_ty>(),
601                            type_name_fn: #helper,
602                            factory: #factory_name,
603                        }
604                    }
605                });
606            }
607            quote! { #(#subs)* }
608        }
609    };
610
611    Ok(quote! {
612        #item
613        #constructor
614        #type_name_helper
615        #(#factory_fns)*
616        #submissions
617    })
618}