rust_config_tree_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 Attribute, Data, DeriveInput, Error, Fields, GenericArgument, LitStr, PathArguments, Type,
5 parse_macro_input,
6};
7
8#[proc_macro_derive(ConfigOverrides, attributes(config_override))]
9pub fn derive_config_overrides(input: TokenStream) -> TokenStream {
10 match expand_config_overrides(parse_macro_input!(input as DeriveInput)) {
11 Ok(tokens) => tokens.into(),
12 Err(err) => err.to_compile_error().into(),
13 }
14}
15
16#[proc_macro_derive(ConfigSchema, attributes(config_schema))]
17pub fn derive_config_schema(input: TokenStream) -> TokenStream {
18 match expand_config_schema(parse_macro_input!(input as DeriveInput)) {
19 Ok(tokens) => tokens.into(),
20 Err(err) => err.to_compile_error().into(),
21 }
22}
23
24fn expand_config_overrides(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
25 let name = input.ident;
26 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
27 let fields = match input.data {
28 Data::Struct(data) => match data.fields {
29 Fields::Named(fields) => fields.named,
30 _ => {
31 return Err(Error::new_spanned(
32 name,
33 "ConfigOverrides only supports structs with named fields",
34 ));
35 }
36 },
37 _ => {
38 return Err(Error::new_spanned(
39 name,
40 "ConfigOverrides only supports structs",
41 ));
42 }
43 };
44
45 let mut inserts = Vec::new();
46 for field in fields {
47 let Some(path) = override_path(&field.attrs)? else {
48 continue;
49 };
50 let ident = field.ident.ok_or_else(|| {
51 Error::new_spanned(&field.ty, "config_override must be used on a named field")
52 })?;
53
54 if option_inner(&field.ty).is_some() {
55 inserts.push(quote! {
56 if let Some(value) = &self.#ident {
57 provider.insert(#path, value)?;
58 }
59 });
60 } else {
61 inserts.push(quote! {
62 provider.insert(#path, &self.#ident)?;
63 });
64 }
65 }
66
67 Ok(quote! {
68 impl #impl_generics ::rust_config_tree::cli::ConfigOverrides for #name #ty_generics #where_clause {
69 fn config_overrides(
70 &self,
71 ) -> ::rust_config_tree::config::ConfigResult<::rust_config_tree::cli::ConfigOverrideProvider> {
72 let mut provider = ::rust_config_tree::cli::ConfigOverrideProvider::new();
73 #(#inserts)*
74 Ok(provider)
75 }
76 }
77 })
78}
79
80fn override_path(attrs: &[Attribute]) -> syn::Result<Option<LitStr>> {
81 let mut path = None;
82
83 for attr in attrs {
84 if !attr.path().is_ident("config_override") {
85 continue;
86 }
87
88 if path.is_some() {
89 return Err(Error::new_spanned(
90 attr,
91 "config_override cannot be repeated on the same field",
92 ));
93 }
94
95 let parsed_path = parse_override_path(attr)?;
96 validate_path(&parsed_path)?;
97 path = Some(parsed_path);
98 }
99
100 Ok(path)
101}
102
103fn parse_override_path(attr: &Attribute) -> syn::Result<LitStr> {
104 if let Ok(path) = attr.parse_args::<LitStr>() {
105 return Ok(path);
106 }
107
108 let mut path = None;
109 attr.parse_nested_meta(|meta| {
110 if !meta.path.is_ident("path") {
111 return Err(meta.error("config_override only supports the path argument"));
112 }
113 let value = meta.value()?;
114 let lit = value.parse::<LitStr>()?;
115 path = Some(lit);
116 Ok(())
117 })?;
118
119 path.ok_or_else(|| Error::new_spanned(attr, "config_override requires a path argument"))
120}
121
122fn validate_path(path: &LitStr) -> syn::Result<()> {
123 let value = path.value();
124 if value.is_empty() {
125 return Err(Error::new_spanned(
126 path,
127 "config_override path must not be empty",
128 ));
129 }
130
131 if value.split('.').any(str::is_empty) {
132 return Err(Error::new_spanned(
133 path,
134 "config_override path must not contain empty segments",
135 ));
136 }
137
138 Ok(())
139}
140
141fn option_inner(ty: &Type) -> Option<&Type> {
142 let Type::Path(type_path) = ty else {
143 return None;
144 };
145 let segment = type_path.path.segments.last()?;
146 if segment.ident != "Option" {
147 return None;
148 }
149 let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
150 return None;
151 };
152 let mut args = arguments.args.iter();
153 let Some(GenericArgument::Type(inner)) = args.next() else {
154 return None;
155 };
156 if args.next().is_some() {
157 return None;
158 }
159 Some(inner)
160}
161
162fn expand_config_schema(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
167 let name = input.ident;
168 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
169
170 let fields = match input.data {
171 Data::Struct(data) => match data.fields {
172 Fields::Named(fields) => fields.named,
173 _ => {
174 return Err(Error::new_spanned(
175 &name,
176 "ConfigSchema only supports structs with named fields",
177 ));
178 }
179 },
180 _ => {
181 return Err(Error::new_spanned(
182 &name,
183 "ConfigSchema only supports structs",
184 ));
185 }
186 };
187
188 let mut include_field: Option<syn::Ident> = None;
190 for field in &fields {
191 if has_config_schema_include_attr(&field.attrs) {
192 let ident = field.ident.clone().ok_or_else(|| {
193 Error::new_spanned(&field.ty, "config_schema(include) must be on a named field")
194 })?;
195 include_field = Some(ident);
196 break;
197 }
198 }
199
200 if include_field.is_none() {
202 for field in &fields {
203 let ident = field.ident.as_ref().ok_or_else(|| {
204 Error::new_spanned(&field.ty, "ConfigSchema requires named fields")
205 })?;
206 if ident == "include" && is_vec_path_buf(&field.ty) {
207 include_field = Some(ident.clone());
208 break;
209 }
210 }
211 }
212
213 let include_ident = include_field.ok_or_else(|| {
214 Error::new_spanned(
215 &name,
216 "ConfigSchema requires a field for include paths. \
217 Annotate one with #[config_schema(include)] or name it `include: Vec<PathBuf>`.",
218 )
219 })?;
220
221 Ok(quote! {
222 impl #impl_generics ::rust_config_tree::config::ConfigSchema for #name #ty_generics #where_clause {
223 fn include_paths(
224 layer: &<Self as ::confique::Config>::Layer,
225 ) -> ::std::vec::Vec<::std::path::PathBuf> {
226 layer.#include_ident.clone().unwrap_or_default()
227 }
228 }
229 })
230}
231
232fn has_config_schema_include_attr(attrs: &[Attribute]) -> bool {
234 for attr in attrs {
235 if !attr.path().is_ident("config_schema") {
236 continue;
237 }
238 if attr
240 .parse_args::<syn::Ident>()
241 .is_ok_and(|ident| ident == "include")
242 {
243 return true;
244 }
245 }
246 false
247}
248
249fn is_vec_path_buf(ty: &Type) -> bool {
252 let Type::Path(type_path) = ty else {
253 return false;
254 };
255 let segment = match type_path.path.segments.last() {
256 Some(s) => s,
257 None => return false,
258 };
259 if segment.ident != "Vec" {
260 return false;
261 }
262 let PathArguments::AngleBracketed(args) = &segment.arguments else {
263 return false;
264 };
265 let Some(GenericArgument::Type(inner)) = args.args.first() else {
266 return false;
267 };
268 is_path_buf(inner)
269}
270
271fn is_path_buf(ty: &Type) -> bool {
273 let Type::Path(type_path) = ty else {
274 return false;
275 };
276 let segment = match type_path.path.segments.last() {
277 Some(s) => s,
278 None => return false,
279 };
280 segment.ident == "PathBuf"
281}