1use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input,
7 punctuated::Punctuated,
8 DeriveInput, Field, Fields, ItemMod, Token,
9};
10
11#[proc_macro_derive(Inject, attributes(inject))]
12pub fn inject_derive(input: TokenStream) -> TokenStream {
13 expand_inject(parse_macro_input!(input as DeriveInput))
14 .unwrap_or_else(|e| e.to_compile_error())
15 .into()
16}
17
18fn expand_inject(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
19 let name = &input.ident;
20 let named = match &input.data {
21 syn::Data::Struct(s) => match &s.fields {
22 Fields::Named(n) => n,
23 _ => return Err(syn::Error::new_spanned(name, "named fields")),
24 },
25 _ => return Err(syn::Error::new_spanned(name, "struct only")),
26 };
27 let fn_name = format_ident!("__rdi_construct_{}", name);
28 let mut inits = Vec::new();
29 for field in named.named.iter() {
30 let attrs = parse_ia(field);
31 let fnm = field.ident.as_ref().unwrap();
32 let (inner, _) = saw(&field.ty);
33 let init = if attrs.skip {
34 quote! {#fnm:Default::default()}
35 } else if attrs.provider {
36 quote! {#fnm:resolver.clone()}
37 } else if attrs.optional {
38 if let Some(k) = &attrs.key {
39 quote! {#fnm:resolver.get_keyed_any(::std::any::type_name::<#inner>(),#k).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d))}
40 } else {
41 quote! {#fnm:resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d))}
42 }
43 } else if let Some(k) = &attrs.key {
44 quote! {#fnm:resolver.get_keyed_any(::std::any::type_name::<#inner>(),#k).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d)).unwrap_or_else(||::std::panic!("keyed not found"))}
45 } else {
46 quote! {#fnm:resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a|a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d|::std::sync::Arc::clone(&*d)).unwrap_or_else(||::std::panic!("svc not registered"))}
47 };
48 inits.push(init);
49 }
50 Ok(
51 quote! {#[doc(hidden)]pub fn #fn_name(resolver:&dyn rust_dicore::IServiceResolver)->::std::sync::Arc<#name>{::std::sync::Arc::new(#name{#(#inits),*})}},
52 )
53}
54
55fn saw(ty: &syn::Type) -> (proc_macro2::TokenStream, bool) {
56 if let syn::Type::Path(p) = ty {
57 let l = p.path.segments.last().unwrap();
58 if l.ident == "Arc" {
59 if let syn::PathArguments::AngleBracketed(a) = &l.arguments {
60 if let Some(syn::GenericArgument::Type(i)) = a.args.first() {
61 return (quote! {#i}, true);
62 }
63 }
64 }
65 if l.ident == "Option" {
66 if let syn::PathArguments::AngleBracketed(a) = &l.arguments {
67 if let Some(syn::GenericArgument::Type(syn::Type::Path(ip))) = a.args.first() {
68 if ip
69 .path
70 .segments
71 .last()
72 .map(|s| s.ident == "Arc")
73 .unwrap_or(false)
74 {
75 if let syn::PathArguments::AngleBracketed(ia) =
76 &ip.path.segments.last().unwrap().arguments
77 {
78 if let Some(syn::GenericArgument::Type(t)) = ia.args.first() {
79 return (quote! {#t}, true);
80 }
81 }
82 }
83 }
84 }
85 }
86 }
87 (quote! {#ty}, false)
88}
89
/// Options parsed from a field's `#[inject(...)]` attributes; everything is off by default.
#[derive(Default)]
struct IA {
    /// `key = "..."`: resolve through the keyed lookup using this key.
    key: Option<String>,
    /// `provider`: initialise the field with `resolver.clone()`.
    provider: bool,
    /// `optional`: keep the lookup result as an `Option` instead of panicking when it is absent.
    optional: bool,
    /// `skip`: do not resolve; initialise with `Default::default()`.
    skip: bool,
}
97fn parse_ia(f: &Field) -> IA {
98 let mut a = IA::default();
99 for attr in &f.attrs {
100 if !attr.path().is_ident("inject") {
101 continue;
102 }
103 let Ok(l) = attr.meta.require_list() else {
104 continue;
105 };
106 l.parse_nested_meta(|m| {
107 if m.path.is_ident("skip") {
108 a.skip = true;
109 } else if m.path.is_ident("optional") {
110 a.optional = true;
111 } else if m.path.is_ident("provider") {
112 a.provider = true;
113 } else if m.path.is_ident("key") {
114 a.key = Some(m.value()?.parse::<syn::LitStr>()?.value());
115 }
116 Ok(())
117 })
118 .ok();
119 }
120 a
121}
122
123#[proc_macro]
124pub fn inject(_: TokenStream) -> TokenStream {
125 quote! {}.into()
126}
127
128#[proc_macro_attribute]
129pub fn module(_: TokenStream, item: TokenStream) -> TokenStream {
130 expand_md(parse_macro_input!(item as ItemMod))
131 .unwrap_or_else(|e| e.to_compile_error())
132 .into()
133}
134
135fn expand_md(mut m: ItemMod) -> syn::Result<proc_macro2::TokenStream> {
136 let mn = m.ident.clone();
137 let fn_n = format_ident!("__rdi_build_provider_{}", mn);
138 let is = match &m.content {
139 Some((_, i)) => i.clone(),
140 None => return Err(syn::Error::new_spanned(m, "body required")),
141 };
142 let mut rs = Vec::new();
143 let mut cl = Vec::new();
144 for i in &is {
145 match i {
146 syn::Item::Macro(mc) => {
147 let ps = mc
148 .mac
149 .path
150 .segments
151 .iter()
152 .map(|s| s.ident.to_string())
153 .collect::<Vec<_>>()
154 .join("::");
155 if ps == "inject" || ps == "rust_dicore::inject" {
156 if let Ok(r) = syn::parse2::<ID>(mc.mac.tokens.clone()) {
157 rs.push(r);
158 }
159 } else {
160 cl.push(i.clone());
161 }
162 }
163 _ => cl.push(i.clone()),
164 }
165 }
166 let mut ch = Vec::new();
167 for r in &rs {
168 match &r.kind {
169 IK::N { lt, ty, imp } => {
170 let mt = lmt(*lt);
171 if let Some(imp_ty) = imp {
172 ch.push(quote!{ .#mt::<#ty>(|_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#imp_ty as ::std::default::Default>::default())) });
173 } else {
174 ch.push(quote!{ .#mt(|_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#ty as ::std::default::Default>::default())) });
175 }
176 }
177 IK::K { key, lt, ty } => {
178 let mt = kmt(*lt);
179 ch.push(quote!{ .#mt::<#ty>(#key,|_:&dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(<#ty as ::std::default::Default>::default())) });
180 }
181 IK::F { lt, f } => {
182 let mt = lmt(*lt);
183 ch.push(
184 quote! { .#mt(move |_: &dyn rust_dicore::IServiceResolver| ::std::sync::Arc::new(#f)) },
185 );
186 }
187 }
188 }
189 vd(&rs)?;
190 let bi: syn::Item = syn::parse2(quote! {
191 #[doc(hidden)]
192 pub fn #fn_n() -> ::std::result::Result<::std::sync::Arc<rust_dicore::ServiceProvider>, rust_dicore::RdiError> {
193 Ok(::std::sync::Arc::new(rust_dicore::ServiceCollection::new() #(#ch)* .build()?))
194 }
195 })
196 .unwrap();
197 cl.push(bi);
198 m.content = Some((syn::token::Brace::default(), cl));
199 Ok(quote! {#m})
200}
201fn lmt(lt: LT) -> proc_macro2::TokenStream {
202 match lt {
203 LT::S => quote! {singleton},
204 LT::Sc => quote! {scoped},
205 LT::T => quote! {transient},
206 }
207}
208fn kmt(lt: LT) -> proc_macro2::TokenStream {
209 match lt {
210 LT::S => quote! {keyed},
211 LT::Sc => quote! {keyed_scoped},
212 LT::T => quote! {keyed_transient},
213 }
214}
215fn vd(rs: &[ID]) -> syn::Result<()> {
216 let mut sn: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
217 for r in rs {
218 if let IK::K { key, .. } = &r.kind {
219 let e = sn.entry(key.clone()).or_default();
220 *e += 1;
221 if *e > 1 {
222 return Err(syn::Error::new(
223 proc_macro2::Span::call_site(),
224 format!("rdi-E004: duplicate key `{key}`"),
225 ));
226 }
227 }
228 }
229 Ok(())
230}
231
232#[derive(Debug, Clone, Copy, PartialEq, Eq)]
233enum LT {
234 S,
235 Sc,
236 T,
237}
238impl Parse for LT {
239 fn parse(i: ParseStream) -> syn::Result<Self> {
240 match i.parse::<syn::Ident>()?.to_string().as_str() {
241 "singleton" => Ok(LT::S),
242 "scoped" => Ok(LT::Sc),
243 "transient" => Ok(LT::T),
244 o => Err(syn::Error::new(i.span(), format!("unknown lifetime: {o}"))),
245 }
246 }
247}
248#[derive(Debug)]
249enum IK {
250 N {
251 lt: LT,
252 ty: syn::Type,
253 imp: Option<syn::Type>,
254 },
255 K {
256 key: String,
257 lt: LT,
258 ty: syn::Type,
259 },
260 F {
261 lt: LT,
262 f: syn::Expr,
263 },
264}
265#[derive(Debug)]
266struct ID {
267 kind: IK,
268}
269impl Parse for ID {
270 fn parse(i: ParseStream) -> syn::Result<Self> {
271 let mk = if i.peek(syn::Ident) && i.fork().parse::<syn::Ident>()? == "keyed" {
272 let _: syn::Ident = i.parse()?;
273 let k: syn::LitStr = i.parse()?;
274 let _: Token![:] = i.parse()?;
275 let lt: LT = i.parse()?;
276 Some((k.value(), lt))
277 } else {
278 None
279 };
280 if i.peek(syn::Ident) && i.fork().parse::<syn::Ident>()? == "factory" {
281 let _: syn::Ident = i.parse()?;
282 let lt: LT = i.parse()?;
283 let _: Token![:] = i.parse()?;
284 let _: syn::Type = i.parse()?;
285 let _: Token![=>] = i.parse()?;
286 let f: syn::Expr = i.parse()?;
287 return Ok(ID {
288 kind: IK::F { lt, f },
289 });
290 }
291 if let Some((k, l)) = mk {
292 let _: Token![:] = i.parse()?;
293 let ty: syn::Type = i.parse()?;
294 let _ = i.parse::<Token![=>]>();
295 if !i.is_empty() && !i.peek(Token![|]) {
296 let _: syn::Type = i.parse()?;
297 }
298 return Ok(ID {
299 kind: IK::K { key: k, lt: l, ty },
300 });
301 }
302 let lt: LT = i.parse()?;
303 let _: Token![:] = i.parse()?;
304 let ty: syn::Type = i.parse()?;
305 let _ = i.parse::<Token![=>]>();
306 let imp: Option<syn::Type> = if !i.is_empty() && !i.peek(Token![|]) {
307 Some(i.parse::<syn::Type>()?)
308 } else {
309 None
310 };
311 Ok(ID {
312 kind: IK::N { lt, ty, imp },
313 })
314 }
315}
316
317enum InjectArgs {
321 Plain {
322 lifetime: LT,
323 },
324 AsTrait {
325 lifetime: LT,
326 trait_ty: syn::Type,
327 },
328 AsTraits {
329 lifetime: LT,
330 trait_tys: Vec<syn::Type>,
331 },
332}
333
334impl Parse for InjectArgs {
335 fn parse(input: ParseStream) -> syn::Result<Self> {
336 let lt: LT = input.parse()?;
337
338 if input.peek(Token![,]) {
339 let _: Token![,] = input.parse()?;
340 if input.peek(Token![as]) {
341 let _: Token![as] = input.parse()?;
342 let _: Token![=] = input.parse()?;
343
344 if input.peek(syn::token::Bracket) {
345 let content;
346 let _ = syn::bracketed!(content in input);
347 let tys: Punctuated<syn::Type, Token![,]> =
348 content.parse_terminated(syn::Type::parse, Token![,])?;
349 return Ok(InjectArgs::AsTraits {
350 lifetime: lt,
351 trait_tys: tys.into_iter().collect(),
352 });
353 } else {
354 let ty: syn::Type = input.parse()?;
355 return Ok(InjectArgs::AsTrait {
356 lifetime: lt,
357 trait_ty: ty,
358 });
359 }
360 }
361 }
362
363 Ok(InjectArgs::Plain { lifetime: lt })
364 }
365}
366
367#[proc_macro_attribute]
368pub fn inject_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
369 expand_inject_attr(
370 parse_macro_input!(attr as InjectArgs),
371 parse_macro_input!(item as syn::Item),
372 )
373 .unwrap_or_else(|e| e.to_compile_error())
374 .into()
375}
376
377fn expand_inject_attr(args: InjectArgs, item: syn::Item) -> syn::Result<proc_macro2::TokenStream> {
378 let struct_item = match &item {
379 syn::Item::Struct(s) => s,
380 _ => return Err(syn::Error::new_spanned(&item, "only structs are supported")),
381 };
382
383 let name = &struct_item.ident;
384 let fn_name = format_ident!("__rdi_construct_{}", name);
385 let factory_name = format_ident!("__rdi_factory_{}", name);
386
387 let constructor_body = match &struct_item.fields {
388 syn::Fields::Named(n) => {
389 let mut inits = Vec::new();
390 for field in n.named.iter() {
391 let attrs = parse_ia(field);
392 let fnm = field.ident.as_ref().unwrap();
393 let (inner, _) = saw(&field.ty);
394
395 let init = if attrs.skip {
396 quote! { #fnm: ::std::default::Default::default() }
397 } else if attrs.provider {
398 quote! { #fnm: resolver.clone() }
399 } else if attrs.optional {
400 if let Some(k) = &attrs.key {
401 quote! { #fnm: resolver.get_keyed_any(::std::any::type_name::<#inner>(), #k).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)) }
402 } else {
403 quote! { #fnm: resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)) }
404 }
405 } else if let Some(k) = &attrs.key {
406 quote! { #fnm: resolver.get_keyed_any(::std::any::type_name::<#inner>(), #k).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)).unwrap_or_else(|| ::std::panic!("keyed not found")) }
407 } else {
408 quote! { #fnm: resolver.get_any(::std::any::type_name::<#inner>()).and_then(|a| a.downcast::<::std::sync::Arc<#inner>>().ok()).map(|d| ::std::sync::Arc::clone(&*d)).unwrap_or_else(|| ::std::panic!("svc not registered")) }
409 };
410 inits.push(init);
411 }
412 quote! { #name { #(#inits),* } }
413 }
414 syn::Fields::Unit => quote! { #name },
415 _ => {
416 return Err(syn::Error::new_spanned(
417 name,
418 "named struct or unit struct required",
419 ))
420 }
421 };
422
423 let constructor = quote! {
424 #[doc(hidden)]
425 #[allow(non_snake_case)]
426 pub fn #fn_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<#name> {
427 ::std::sync::Arc::new(#constructor_body)
428 }
429 };
430
431 let lt_ident = match args {
432 InjectArgs::Plain { lifetime: LT::S }
433 | InjectArgs::AsTrait {
434 lifetime: LT::S, ..
435 }
436 | InjectArgs::AsTraits {
437 lifetime: LT::S, ..
438 } => {
439 quote! { rust_dicore::ServiceLifetime::Singleton }
440 }
441 InjectArgs::Plain { lifetime: LT::Sc }
442 | InjectArgs::AsTrait {
443 lifetime: LT::Sc, ..
444 }
445 | InjectArgs::AsTraits {
446 lifetime: LT::Sc, ..
447 } => {
448 quote! { rust_dicore::ServiceLifetime::Scoped }
449 }
450 InjectArgs::Plain { lifetime: LT::T }
451 | InjectArgs::AsTrait {
452 lifetime: LT::T, ..
453 }
454 | InjectArgs::AsTraits {
455 lifetime: LT::T, ..
456 } => {
457 quote! { rust_dicore::ServiceLifetime::Transient }
458 }
459 };
460
461 let factory_fns: Vec<proc_macro2::TokenStream> = match &args {
463 InjectArgs::Plain { .. } => {
464 vec![quote! {
465 #[doc(hidden)]
466 #[allow(non_snake_case)]
467 fn #factory_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
468 let v: ::std::sync::Arc<#name> = #fn_name(resolver);
469 ::std::sync::Arc::new(v)
470 as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
471 }
472 }]
473 }
474 InjectArgs::AsTrait { trait_ty, .. } => {
475 vec![quote! {
476 #[doc(hidden)]
477 #[allow(non_snake_case)]
478 fn #factory_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
479 let v: ::std::sync::Arc<#name> = #fn_name(resolver);
480 let v2: ::std::sync::Arc<#trait_ty> = v;
481 ::std::sync::Arc::new(v2)
482 as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
483 }
484 }]
485 }
486 InjectArgs::AsTraits { trait_tys, .. } => {
487 trait_tys.iter().enumerate().map(|(i, trait_ty)| {
488 let fn_name = if i == 0 {
489 factory_name.clone()
490 } else {
491 format_ident!("__rdi_factory_{}_{}", name, i)
492 };
493 quote! {
494 #[doc(hidden)]
495 #[allow(non_snake_case)]
496 fn #fn_name(resolver: &dyn rust_dicore::IServiceResolver) -> ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
497 let v: ::std::sync::Arc<#name> = #fn_name(resolver);
498 let v2: ::std::sync::Arc<#trait_ty> = v;
499 ::std::sync::Arc::new(v2)
500 as ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>
501 }
502 }
503 }).collect()
504 }
505 };
506
507 let type_name_fn_name = format_ident!("__rdi_type_name_{}", name);
508
509 let (type_name_helper, trait_tys_for_subs): (proc_macro2::TokenStream, Vec<syn::Type>) =
510 match &args {
511 InjectArgs::Plain { .. } => {
512 let helper = quote! {
513 #[doc(hidden)]
514 #[allow(non_snake_case)]
515 fn #type_name_fn_name() -> &'static str {
516 ::std::any::type_name::<#name>()
517 }
518 };
519 (helper, vec![])
520 }
521 InjectArgs::AsTrait { trait_ty, .. } => {
522 let helper = quote! {
523 #[doc(hidden)]
524 #[allow(non_snake_case)]
525 fn #type_name_fn_name() -> &'static str {
526 ::std::any::type_name::<#trait_ty>()
527 }
528 };
529 (helper, vec![trait_ty.clone()])
530 }
531 InjectArgs::AsTraits { trait_tys, .. } => {
532 let first = &trait_tys[0];
533 let helper = quote! {
534 #[doc(hidden)]
535 #[allow(non_snake_case)]
536 fn #type_name_fn_name() -> &'static str {
537 ::std::any::type_name::<#first>()
538 }
539 };
540 let extra: Vec<proc_macro2::TokenStream> = trait_tys[1..]
542 .iter()
543 .enumerate()
544 .map(|(i, ty)| {
545 let hn = format_ident!("__rdi_type_name_{}_{}", name, i + 1);
546 quote! {
547 #[doc(hidden)]
548 #[allow(non_snake_case)]
549 fn #hn() -> &'static str {
550 ::std::any::type_name::<#ty>()
551 }
552 }
553 })
554 .collect();
555 let all_helpers = quote! {
556 #helper
557 #(#extra)*
558 };
559 (all_helpers, trait_tys.clone())
560 }
561 };
562
563 let submissions = match &args {
564 InjectArgs::Plain { .. } => {
565 quote! {
566 rust_dicore::inventory::submit! {
567 rust_dicore::ServiceRegistration {
568 lifetime: #lt_ident,
569 type_id: ::std::any::TypeId::of::<#name>(),
570 type_name_fn: #type_name_fn_name,
571 factory: #factory_name,
572 }
573 }
574 }
575 }
576 InjectArgs::AsTrait { trait_ty, .. } => {
577 quote! {
578 rust_dicore::inventory::submit! {
579 rust_dicore::ServiceRegistration {
580 lifetime: #lt_ident,
581 type_id: ::std::any::TypeId::of::<#trait_ty>(),
582 type_name_fn: #type_name_fn_name,
583 factory: #factory_name,
584 }
585 }
586 }
587 }
588 InjectArgs::AsTraits { .. } => {
589 let mut subs = Vec::new();
590 for (i, trait_ty) in trait_tys_for_subs.iter().enumerate() {
591 let helper = if i == 0 {
592 type_name_fn_name.clone()
593 } else {
594 format_ident!("__rdi_type_name_{}_{}", name, i)
595 };
596 subs.push(quote! {
597 rust_dicore::inventory::submit! {
598 rust_dicore::ServiceRegistration {
599 lifetime: #lt_ident,
600 type_id: ::std::any::TypeId::of::<#trait_ty>(),
601 type_name_fn: #helper,
602 factory: #factory_name,
603 }
604 }
605 });
606 }
607 quote! { #(#subs)* }
608 }
609 };
610
611 Ok(quote! {
612 #item
613 #constructor
614 #type_name_helper
615 #(#factory_fns)*
616 #submissions
617 })
618}