Skip to main content

studiole_di_macros/
lib.rs

1//! Derive macros for `studiole-di`.
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, parse_macro_input};
5
6/// Derive macro for implementing the `Service` trait.
7///
8/// Generates a `Service` implementation that resolves all fields via
9/// `ServiceProvider::get_service()`.
10///
11/// # Example
12///
13/// ```ignore
14/// #[derive(Service)]
15/// pub struct MyHandler {
16///     http: Arc<HttpClient>,
17///     metadata: Arc<MetadataRepository>,
18/// }
19/// ```
20///
21/// Generates:
22///
23/// ```ignore
24/// impl Service for MyHandler {
25///     type Error = ServiceError;
26///
27///     async fn from_services(
28///         services: &ServiceProvider,
29///     ) -> Result<Self, Report<Self::Error>> {
30///         Ok(Self {
31///             http: services.get_service().await?,
32///             metadata: services.get_service().await?,
33///         })
34///     }
35/// }
36/// ```
37#[proc_macro_derive(Service)]
38pub fn derive_service(input: TokenStream) -> TokenStream {
39    let input = parse_macro_input!(input as DeriveInput);
40    let name = &input.ident;
41
42    let fields = match &input.data {
43        Data::Struct(data) => match &data.fields {
44            Fields::Named(fields) => fields,
45            _ => {
46                return syn::Error::new_spanned(
47                    &input,
48                    "Service derive only supports structs with named fields",
49                )
50                .to_compile_error()
51                .into();
52            }
53        },
54        _ => {
55            return syn::Error::new_spanned(&input, "Service derive only supports structs")
56                .to_compile_error()
57                .into();
58        }
59    };
60
61    let field_names: Vec<_> = fields
62        .named
63        .iter()
64        .filter_map(|f| f.ident.as_ref())
65        .collect();
66
67    let expanded = quote! {
68        impl Service for #name {
69            type Error = ServiceError;
70
71            async fn from_services(
72                services: &ServiceProvider,
73            ) -> Result<Self, Report<Self::Error>> {
74                Ok(Self {
75                    #(#field_names: services.get_service().await?,)*
76                })
77            }
78        }
79    };
80
81    TokenStream::from(expanded)
82}