Skip to main content

serviceconf_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
6
7mod attrs;
8
9use attrs::FieldAttrs;
10
11/// Extract the inner type `T` from `Option<T>`, returning the original type if not an Option.
12///
13/// This helper is used to generate correct error messages and deserializer calls
14/// for optional fields, where the inner type needs to be referenced separately.
15fn extract_option_inner_type(ty: &Type) -> &Type {
16    if let Type::Path(type_path) = ty {
17        if let Some(seg) = type_path.path.segments.last() {
18            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
19                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
20                    return inner;
21                }
22            }
23        }
24    }
25    ty
26}
27
28/// `ServiceConf` derive macro
29///
30/// Automatically implements the `from_env()` method on structs for loading configuration
31/// from environment variables.
32///
33/// # Supported Attributes
34///
35/// ## Struct-level Attributes
36///
37/// ### `#[conf(prefix = "PREFIX_")]`
38/// Add a prefix to all environment variable names in the struct.
39///
40/// ```no_run
41/// use serviceconf::ServiceConf;
42///
43/// #[derive(ServiceConf)]
44/// #[conf(prefix = "MYAPP_")]
45/// struct Config {
46///     pub api_key: String,  // Reads from MYAPP_API_KEY
47///     pub port: u16,        // Reads from MYAPP_PORT
48/// }
49/// ```
50///
51/// ## Field-level Attributes
52///
53/// ### `#[conf(name = "CUSTOM_NAME")]`
54/// Override the default environment variable name for a specific field.
55///
56/// ```no_run
57/// use serviceconf::ServiceConf;
58///
59/// #[derive(ServiceConf)]
60/// struct Config {
61///     #[conf(name = "DATABASE_URL")]
62///     pub db_connection: String,  // Reads from DATABASE_URL
63/// }
64/// ```
65///
66/// ### `#[conf(default)]`
67/// Use `Default::default()` when the environment variable is not set.
68///
69/// ```no_run
70/// use serviceconf::ServiceConf;
71///
72/// #[derive(ServiceConf)]
73/// struct Config {
74///     #[conf(default)]
75///     pub port: u16,  // Uses 0 if PORT not set
76/// }
77/// ```
78///
79/// ### `#[conf(default = value)]`
80/// Use an explicit default value when the environment variable is not set.
81///
82/// ```no_run
83/// use serviceconf::ServiceConf;
84///
85/// #[derive(ServiceConf)]
86/// struct Config {
87///     #[conf(default = 8080)]
88///     pub port: u16,  // Uses 8080 if PORT not set
89/// }
90/// ```
91///
92/// ### `#[conf(from_file)]`
93/// Support loading from file-based secrets (Kubernetes/Docker Secrets).
94/// Reads from both `VAR_NAME` and `VAR_NAME_FILE` environment variables.
95///
96/// ```no_run
97/// use serviceconf::ServiceConf;
98///
99/// #[derive(ServiceConf)]
100/// struct Config {
101///     #[conf(from_file)]
102///     pub api_key: String,  // Reads from API_KEY or API_KEY_FILE
103/// }
104/// ```
105///
106/// ### `#[conf(deserializer = "function")]`
107/// Use a custom deserializer function for complex types.
108///
109/// The function signature must be: `fn(&str) -> Result<T, impl std::fmt::Display>`
110///
111/// Can be combined with `default` to provide a fallback value, or used with `Option<T>`
112/// to make the field optional:
113///
114/// ```no_run
115/// use serviceconf::ServiceConf;
116/// use std::time::Duration;
117///
118/// fn parse_duration_secs(s: &str) -> Result<Duration, String> {
119///     s.parse::<u64>()
120///         .map(Duration::from_secs)
121///         .map_err(|e| format!("Failed to parse: {}", e))
122/// }
123///
124/// #[derive(ServiceConf)]
125/// struct Config {
126///     // Required field with custom deserializer
127///     #[conf(deserializer = "parse_duration_secs")]
128///     pub timeout: Duration,
129///
130///     // With default value (uses default when env var is not set)
131///     #[conf(deserializer = "parse_duration_secs", default = Duration::from_secs(60))]
132///     pub retry_interval: Duration,
133///
134///     // With Option<T> (None when env var is not set)
135///     #[conf(deserializer = "parse_duration_secs")]
136///     pub max_timeout: Option<Duration>,
137/// }
138/// ```
139///
140/// # Examples
141///
142/// **Basic usage:**
143/// ```no_run
144/// use serviceconf::ServiceConf;
145///
146/// #[derive(ServiceConf)]
147/// struct Config {
148///     pub api_key: String,
149///
150///     #[conf(default = 8080)]
151///     pub port: u16,
152/// }
153///
154/// fn main() -> anyhow::Result<()> {
155///     let config = Config::from_env()?;
156///     Ok(())
157/// }
158/// ```
159///
160/// **With prefix and file-based secrets:**
161/// ```no_run
162/// use serviceconf::ServiceConf;
163///
164/// #[derive(ServiceConf)]
165/// #[conf(prefix = "APP_")]
166/// struct Config {
167///     #[conf(from_file)]
168///     pub database_password: String,  // Reads from APP_DATABASE_PASSWORD or APP_DATABASE_PASSWORD_FILE
169///
170///     #[conf(default = 3000)]
171///     pub port: u16,  // Reads from APP_PORT, defaults to 3000
172/// }
173/// ```
174///
175/// **With custom deserializer and default:**
176/// ```
177/// use serviceconf::ServiceConf;
178/// use std::time::Duration;
179///
180/// fn parse_duration_secs(s: &str) -> Result<Duration, String> {
181///     s.parse::<u64>()
182///         .map(Duration::from_secs)
183///         .map_err(|e| e.to_string())
184/// }
185///
186/// fn parse_comma_list(s: &str) -> Result<Vec<String>, String> {
187///     Ok(s.split(',').map(|s| s.trim().to_string()).collect())
188/// }
189///
190/// #[derive(ServiceConf)]
191/// struct Config {
192///     // Custom deserializer with explicit default value
193///     #[conf(deserializer = "parse_duration_secs", default = Duration::from_secs(30))]
194///     pub timeout: Duration,
195///
196///     // Custom deserializer with Default::default()
197///     #[conf(deserializer = "parse_comma_list", default)]
198///     pub allowed_hosts: Vec<String>,
199/// }
200///
201/// // Uses default values when environment variables are not set
202/// std::env::remove_var("TIMEOUT");
203/// std::env::remove_var("ALLOWED_HOSTS");
204/// let config = Config::from_env().unwrap();
205/// assert_eq!(config.timeout, Duration::from_secs(30));
206/// assert_eq!(config.allowed_hosts, Vec::<String>::new());
207///
208/// // Override with environment variables
209/// std::env::set_var("TIMEOUT", "60");
210/// std::env::set_var("ALLOWED_HOSTS", "localhost, example.com");
211/// let config = Config::from_env().unwrap();
212/// assert_eq!(config.timeout, Duration::from_secs(60));
213/// assert_eq!(config.allowed_hosts, vec!["localhost", "example.com"]);
214/// # std::env::remove_var("TIMEOUT");
215/// # std::env::remove_var("ALLOWED_HOSTS");
216/// ```
217///
218/// For complete documentation and more examples, see the [`serviceconf`](https://docs.rs/serviceconf) crate.
219#[proc_macro_derive(ServiceConf, attributes(conf))]
220pub fn derive_serviceconf(input: TokenStream) -> TokenStream {
221    let input = parse_macro_input!(input as DeriveInput);
222
223    // Struct name
224    let struct_name = &input.ident;
225
226    // Parse struct-level attributes (prefix)
227    let mut prefix = String::new();
228
229    for attr in &input.attrs {
230        if !attr.path().is_ident("conf") {
231            continue;
232        }
233
234        let _ = attr.parse_nested_meta(|meta| {
235            if meta.path.is_ident("prefix") {
236                let value = meta.value()?;
237                let lit: syn::Lit = value.parse()?;
238                if let syn::Lit::Str(s) = lit {
239                    prefix = s.value();
240                }
241                return Ok(());
242            }
243
244            Err(meta.error("unsupported struct-level conf attribute"))
245        });
246    }
247
248    // Extract fields
249    let fields = match &input.data {
250        Data::Struct(data) => match &data.fields {
251            Fields::Named(fields) => &fields.named,
252            _ => {
253                return syn::Error::new_spanned(
254                    &input,
255                    "ServiceConf only supports structs with named fields",
256                )
257                .to_compile_error()
258                .into();
259            }
260        },
261        _ => {
262            return syn::Error::new_spanned(&input, "ServiceConf only supports structs")
263                .to_compile_error()
264                .into();
265        }
266    };
267
268    // Validate field attributes before code generation to avoid malformed error tokens
269    for field in fields.iter() {
270        let field_type = &field.ty;
271        let attrs = FieldAttrs::from_field(field);
272
273        // Check if type is Option<T>
274        let is_option = if let syn::Type::Path(type_path) = field_type {
275            type_path
276                .path
277                .segments
278                .last()
279                .map(|seg| seg.ident == "Option")
280                .unwrap_or(false)
281        } else {
282            false
283        };
284
285        // Validate invalid attribute combinations
286        if is_option && attrs.default.is_some() {
287            return syn::Error::new_spanned(
288                field,
289                "Option<T> fields cannot have default attribute (they default to None automatically)",
290            )
291            .to_compile_error()
292            .into();
293        }
294    }
295
296    // Generate deserialization code for each field
297    let field_initializers = fields.iter().map(|field| {
298        let field_name = field.ident.as_ref().unwrap();
299        let field_type = &field.ty;
300
301        // Parse attributes
302        let attrs = FieldAttrs::from_field(field);
303
304        // Check if type is Option<T>
305        let is_option = if let syn::Type::Path(type_path) = field_type {
306            type_path.path.segments.last()
307                .map(|seg| seg.ident == "Option")
308                .unwrap_or(false)
309        } else {
310            false
311        };
312
313        // Determine environment variable name
314        let base_name = attrs.name.unwrap_or_else(|| {
315            // Convert field name to UPPER_SNAKE_CASE
316            field_name.to_string().to_uppercase()
317        });
318
319        // Apply prefix
320        let env_var_name = format!("{}{}", prefix, base_name);
321
322        let load_from_file = attrs.from_file;
323        let deserializer_fn = attrs.deserializer;
324
325        // Generate deserialization expression
326        let deserialize_expr = if is_option && deserializer_fn.is_none() {
327            // Option<T> without deserializer
328            let inner_type = extract_option_inner_type(field_type);
329
330            quote! {
331                ::serviceconf::de::deserialize_optional::<#inner_type>(
332                    #env_var_name,
333                    #load_from_file
334                )?
335            }
336        } else if let Some(func_path) = deserializer_fn {
337            // Use custom deserializer function
338            let func: proc_macro2::TokenStream = func_path.parse().unwrap();
339
340            if is_option {
341                // Option<T> with deserializer
342                let inner_type = extract_option_inner_type(field_type);
343
344                quote! {
345                    match ::serviceconf::de::get_env_value(#env_var_name, #load_from_file) {
346                        Ok(__value) => Some(#func(&__value).map_err(|e| ::serviceconf::ServiceConfError::parse_error::<#inner_type>(#env_var_name, e))?),
347                        Err(::serviceconf::ServiceConfError::Missing { .. }) => None,
348                        Err(e) => return Err(e.into()),
349                    }
350                }
351            } else {
352                // Non-Option with deserializer
353                match attrs.default {
354                    Some(Some(default_value)) => {
355                        // Explicit default value with deserializer
356                        quote! {
357                            match ::serviceconf::de::get_env_value(#env_var_name, #load_from_file) {
358                                Ok(__value) => #func(&__value).map_err(|e| ::serviceconf::ServiceConfError::parse_error::<#field_type>(#env_var_name, e))?,
359                                Err(::serviceconf::ServiceConfError::Missing { .. }) => #default_value,
360                                Err(e) => return Err(e.into()),
361                            }
362                        }
363                    }
364                    Some(None) => {
365                        // Use Default::default() with deserializer
366                        quote! {
367                            match ::serviceconf::de::get_env_value(#env_var_name, #load_from_file) {
368                                Ok(__value) => #func(&__value).map_err(|e| ::serviceconf::ServiceConfError::parse_error::<#field_type>(#env_var_name, e))?,
369                                Err(::serviceconf::ServiceConfError::Missing { .. }) => Default::default(),
370                                Err(e) => return Err(e.into()),
371                            }
372                        }
373                    }
374                    None => {
375                        // Required field with deserializer
376                        quote! {
377                            {
378                                let __value = ::serviceconf::de::get_env_value(#env_var_name, #load_from_file)?;
379                                #func(&__value).map_err(|e| ::serviceconf::ServiceConfError::parse_error::<#field_type>(#env_var_name, e))?
380                            }
381                        }
382                    }
383                }
384            }
385        } else {
386            // Use FromStr deserialization (default)
387            match attrs.default {
388                Some(Some(default_value)) => {
389                    // Explicit default value
390                    quote! {
391                        ::serviceconf::de::deserialize_with_default::<#field_type>(
392                            #env_var_name,
393                            #load_from_file,
394                            #default_value
395                        )?
396                    }
397                }
398                Some(None) => {
399                    // Use Default::default()
400                    quote! {
401                        ::serviceconf::de::deserialize_with_default::<#field_type>(
402                            #env_var_name,
403                            #load_from_file,
404                            Default::default()
405                        )?
406                    }
407                }
408                None => {
409                    // Required field
410                    quote! {
411                        ::serviceconf::de::deserialize_required::<#field_type>(
412                            #env_var_name,
413                            #load_from_file
414                        )?
415                    }
416                }
417            }
418        };
419
420        quote! {
421            #field_name: #deserialize_expr
422        }
423    });
424
425    // Generate from_env() method
426    let expanded = quote! {
427        impl #struct_name {
428            /// Load configuration from environment variables
429            ///
430            /// # Errors
431            ///
432            /// - Required environment variables are not set
433            /// - Environment variable values cannot be parsed into target types
434            /// - File-based configuration fails to read files
435            pub fn from_env() -> ::serviceconf::anyhow::Result<Self> {
436                Ok(Self {
437                    #(#field_initializers),*
438                })
439            }
440        }
441    };
442
443    TokenStream::from(expanded)
444}