Skip to main content

rust_config_tree_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DeriveInput, Error, Fields, GenericArgument, LitStr, PathArguments, Type,
5    parse_macro_input,
6};
7
8#[proc_macro_derive(ConfigOverrides, attributes(config_override))]
9pub fn derive_config_overrides(input: TokenStream) -> TokenStream {
10    match expand_config_overrides(parse_macro_input!(input as DeriveInput)) {
11        Ok(tokens) => tokens.into(),
12        Err(err) => err.to_compile_error().into(),
13    }
14}
15
16#[proc_macro_derive(ConfigSchema, attributes(config_schema))]
17pub fn derive_config_schema(input: TokenStream) -> TokenStream {
18    match expand_config_schema(parse_macro_input!(input as DeriveInput)) {
19        Ok(tokens) => tokens.into(),
20        Err(err) => err.to_compile_error().into(),
21    }
22}
23
24fn expand_config_overrides(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
25    let name = input.ident;
26    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
27    let fields = match input.data {
28        Data::Struct(data) => match data.fields {
29            Fields::Named(fields) => fields.named,
30            _ => {
31                return Err(Error::new_spanned(
32                    name,
33                    "ConfigOverrides only supports structs with named fields",
34                ));
35            }
36        },
37        _ => {
38            return Err(Error::new_spanned(
39                name,
40                "ConfigOverrides only supports structs",
41            ));
42        }
43    };
44
45    let mut inserts = Vec::new();
46    for field in fields {
47        let Some(path) = override_path(&field.attrs)? else {
48            continue;
49        };
50        let ident = field.ident.ok_or_else(|| {
51            Error::new_spanned(&field.ty, "config_override must be used on a named field")
52        })?;
53
54        if option_inner(&field.ty).is_some() {
55            inserts.push(quote! {
56                if let Some(value) = &self.#ident {
57                    provider.insert(#path, value)?;
58                }
59            });
60        } else {
61            inserts.push(quote! {
62                provider.insert(#path, &self.#ident)?;
63            });
64        }
65    }
66
67    Ok(quote! {
68        impl #impl_generics ::rust_config_tree::cli::ConfigOverrides for #name #ty_generics #where_clause {
69            fn config_overrides(
70                &self,
71            ) -> ::rust_config_tree::config::ConfigResult<::rust_config_tree::cli::ConfigOverrideProvider> {
72                let mut provider = ::rust_config_tree::cli::ConfigOverrideProvider::new();
73                #(#inserts)*
74                Ok(provider)
75            }
76        }
77    })
78}
79
80fn override_path(attrs: &[Attribute]) -> syn::Result<Option<LitStr>> {
81    let mut path = None;
82
83    for attr in attrs {
84        if !attr.path().is_ident("config_override") {
85            continue;
86        }
87
88        if path.is_some() {
89            return Err(Error::new_spanned(
90                attr,
91                "config_override cannot be repeated on the same field",
92            ));
93        }
94
95        let parsed_path = parse_override_path(attr)?;
96        validate_path(&parsed_path)?;
97        path = Some(parsed_path);
98    }
99
100    Ok(path)
101}
102
103fn parse_override_path(attr: &Attribute) -> syn::Result<LitStr> {
104    if let Ok(path) = attr.parse_args::<LitStr>() {
105        return Ok(path);
106    }
107
108    let mut path = None;
109    attr.parse_nested_meta(|meta| {
110        if !meta.path.is_ident("path") {
111            return Err(meta.error("config_override only supports the path argument"));
112        }
113        let value = meta.value()?;
114        let lit = value.parse::<LitStr>()?;
115        path = Some(lit);
116        Ok(())
117    })?;
118
119    path.ok_or_else(|| Error::new_spanned(attr, "config_override requires a path argument"))
120}
121
122fn validate_path(path: &LitStr) -> syn::Result<()> {
123    let value = path.value();
124    if value.is_empty() {
125        return Err(Error::new_spanned(
126            path,
127            "config_override path must not be empty",
128        ));
129    }
130
131    if value.split('.').any(str::is_empty) {
132        return Err(Error::new_spanned(
133            path,
134            "config_override path must not contain empty segments",
135        ));
136    }
137
138    Ok(())
139}
140
141fn option_inner(ty: &Type) -> Option<&Type> {
142    let Type::Path(type_path) = ty else {
143        return None;
144    };
145    let segment = type_path.path.segments.last()?;
146    if segment.ident != "Option" {
147        return None;
148    }
149    let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
150        return None;
151    };
152    let mut args = arguments.args.iter();
153    let Some(GenericArgument::Type(inner)) = args.next() else {
154        return None;
155    };
156    if args.next().is_some() {
157        return None;
158    }
159    Some(inner)
160}
161
162// ---------------------------------------------------------------------------
163// ConfigSchema derive
164// ---------------------------------------------------------------------------
165
166fn expand_config_schema(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
167    let name = input.ident;
168    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170    let fields = match input.data {
171        Data::Struct(data) => match data.fields {
172            Fields::Named(fields) => fields.named,
173            _ => {
174                return Err(Error::new_spanned(
175                    &name,
176                    "ConfigSchema only supports structs with named fields",
177                ));
178            }
179        },
180        _ => {
181            return Err(Error::new_spanned(
182                &name,
183                "ConfigSchema only supports structs",
184            ));
185        }
186    };
187
188    // 1. Look for a field annotated with #[config_schema(include)].
189    let mut include_field: Option<syn::Ident> = None;
190    for field in &fields {
191        if has_config_schema_include_attr(&field.attrs) {
192            let ident = field.ident.clone().ok_or_else(|| {
193                Error::new_spanned(&field.ty, "config_schema(include) must be on a named field")
194            })?;
195            include_field = Some(ident);
196            break;
197        }
198    }
199
200    // 2. Fall back to a field named `include` whose type is Vec<PathBuf>.
201    if include_field.is_none() {
202        for field in &fields {
203            let ident = field.ident.as_ref().ok_or_else(|| {
204                Error::new_spanned(&field.ty, "ConfigSchema requires named fields")
205            })?;
206            if ident == "include" && is_vec_path_buf(&field.ty) {
207                include_field = Some(ident.clone());
208                break;
209            }
210        }
211    }
212
213    let include_ident = include_field.ok_or_else(|| {
214        Error::new_spanned(
215            &name,
216            "ConfigSchema requires a field for include paths. \
217             Annotate one with #[config_schema(include)] or name it `include: Vec<PathBuf>`.",
218        )
219    })?;
220
221    Ok(quote! {
222        impl #impl_generics ::rust_config_tree::config::ConfigSchema for #name #ty_generics #where_clause {
223            fn include_paths(
224                layer: &<Self as ::confique::Config>::Layer,
225            ) -> ::std::vec::Vec<::std::path::PathBuf> {
226                layer.#include_ident.clone().unwrap_or_default()
227            }
228        }
229    })
230}
231
232/// Checks whether a field carries `#[config_schema(include)]`.
233fn has_config_schema_include_attr(attrs: &[Attribute]) -> bool {
234    for attr in attrs {
235        if !attr.path().is_ident("config_schema") {
236            continue;
237        }
238        // Accept `#[config_schema(include)]`.
239        if attr
240            .parse_args::<syn::Ident>()
241            .is_ok_and(|ident| ident == "include")
242        {
243            return true;
244        }
245    }
246    false
247}
248
249/// Returns `true` when the type is `Vec<PathBuf>` (with any leading
250/// `std::` / `::std::` qualifiers).
251fn is_vec_path_buf(ty: &Type) -> bool {
252    let Type::Path(type_path) = ty else {
253        return false;
254    };
255    let segment = match type_path.path.segments.last() {
256        Some(s) => s,
257        None => return false,
258    };
259    if segment.ident != "Vec" {
260        return false;
261    }
262    let PathArguments::AngleBracketed(args) = &segment.arguments else {
263        return false;
264    };
265    let Some(GenericArgument::Type(inner)) = args.args.first() else {
266        return false;
267    };
268    is_path_buf(inner)
269}
270
271/// Returns `true` when the type resolves to `PathBuf` (possibly qualified).
272fn is_path_buf(ty: &Type) -> bool {
273    let Type::Path(type_path) = ty else {
274        return false;
275    };
276    let segment = match type_path.path.segments.last() {
277        Some(s) => s,
278        None => return false,
279    };
280    segment.ident == "PathBuf"
281}