service_skeleton_derive/
lib.rs

1use darling::{ast, util::Flag, util::SpannedValue, FromDeriveInput, FromField};
2use heck::AsShoutySnekCase;
3use proc_macro2::TokenStream;
4use quote::{quote, quote_spanned, ToTokens};
5use syn::{parse_macro_input, spanned::Spanned, ExprPath, Ident, Type};
6
7#[proc_macro_derive(ServiceConfig, attributes(config))]
8pub fn derive_service_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9	let input = parse_macro_input!(input);
10
11	let receiver = match ServiceConfigReceiver::from_derive_input(&input) {
12		Ok(r) => r,
13		Err(e) => return e.write_errors().into(),
14	};
15
16	quote!(#receiver).into()
17}
18
19#[derive(Debug, FromDeriveInput)]
20#[darling(supports(struct_named))]
21struct ServiceConfigReceiver {
22	ident: Ident,
23	generics: syn::Generics,
24	data: ast::Data<(), SpannedValue<ServiceConfigField>>,
25}
26
27impl ToTokens for ServiceConfigReceiver {
28	fn to_tokens(&self, tokens: &mut TokenStream) {
29		let struct_name = &self.ident;
30		let (imp, ty, wher) = self.generics.split_for_impl();
31
32		let mut fields: Vec<TokenStream> = Vec::new();
33		let mut purges: Vec<TokenStream> = Vec::new();
34
35		#[allow(clippy::expect_used)] // Ensured by darling(supports(struct_named))
36		for f in self
37			.data
38			.as_ref()
39			.take_struct()
40			.expect("data to be a struct")
41			.fields
42		{
43			fields.push(f.field_init());
44
45			purges.push(f.purge_sensitive());
46		}
47
48		tokens.extend(quote! {
49			impl #imp ServiceConfig for #struct_name #ty #wher {
50				fn from_env_vars(prefix: &str, vars: impl Iterator<Item = (String, String)>) -> Result<#struct_name, service_skeleton::Error> {
51					let prefix = ::service_skeleton::heck::AsShoutySnekCase(prefix).to_string();
52					let var_map: ::std::collections::HashMap<String, String> = vars.collect();
53					let mut key_map: ::std::collections::HashMap<::service_skeleton::config::Key, ::secrecy::SecretString> = ::std::collections::HashMap::new();
54
55					let cfg = #struct_name {
56						#(#fields)*
57					};
58
59					#(#purges)*
60
61					Ok(cfg)
62				}
63			}
64		});
65	}
66}
67
68#[derive(Debug, FromField)]
69#[darling(attributes(config))]
70struct ServiceConfigField {
71	ident: Option<Ident>,
72	ty: Type,
73
74	default_value: Option<SpannedValue<String>>,
75	value_parser: Option<SpannedValue<ExprPath>>,
76	encrypted: Flag,
77	sensitive: Flag,
78	key_file_field: Option<SpannedValue<String>>,
79}
80
81impl ServiceConfigField {
82	fn field_init(&self) -> TokenStream {
83		let field_name = self.field_name();
84		let fmt_str = Self::env_var_format_string(&field_name.to_string());
85		let value_parser = self.value_parser();
86		let default_value = self.default_value();
87		let fetch_value = self.fetch_value();
88
89		if self.is_optional() {
90			quote_spanned! { self.ident.span()=>
91				#field_name: ::service_skeleton::config::determine_optional_value(
92					&format!(#fmt_str, prefix),
93					#value_parser,
94					#fetch_value,
95					#default_value
96				)?,
97			}
98		} else {
99			quote_spanned! { self.ident.span()=>
100				#field_name: ::service_skeleton::config::determine_value(
101					&format!(#fmt_str, prefix),
102					#value_parser,
103					#fetch_value,
104					#default_value
105				)?,
106			}
107		}
108	}
109
110	fn fetch_value(&self) -> TokenStream {
111		let field_var_fmt_str = Self::env_var_format_string(&self.field_name().to_string());
112
113		if self.encrypted.is_present() {
114			if let Some(ref key_file_field) = self.key_file_field {
115				let key_var_fmt_str = Self::env_var_format_string(key_file_field);
116
117				quote_spanned! { self.ident.span()=>
118					::service_skeleton::config::fetch_encrypted_field(&var_map, &mut key_map, &format!(#field_var_fmt_str, prefix), &::service_skeleton::config::Key::File(format!(#key_var_fmt_str, prefix)))?.as_deref()
119				}
120			} else {
121				quote_spanned! { self.encrypted.span()=>
122					compile_error!("field is encrypted but no key_file was specified to decrypt");
123				}
124			}
125		} else {
126			quote_spanned! { self.ident.span()=>
127				var_map.get(&format!(#field_var_fmt_str, prefix)).map(::std::string::String::as_str)
128			}
129		}
130	}
131
132	fn purge_sensitive(&self) -> TokenStream {
133		if self.is_sensitive() {
134			let fmt_str = Self::env_var_format_string(&self.field_name().to_string());
135			quote_spanned! { self.ident.span()=>
136				::tracing::debug!("Removing sensitive env var {}", format!(#fmt_str, prefix));
137				::std::env::remove_var(&format!(#fmt_str, prefix));
138			}
139		} else {
140			quote! {}
141		}
142	}
143
144	fn field_name(&self) -> &Ident {
145		#[allow(clippy::expect_used)]
146		self.ident
147			.as_ref()
148			.expect("named field does not have a field")
149	}
150
151	fn env_var_format_string(field_name: &str) -> String {
152		format!(
153			"{{}}_{shouty_field_name}",
154			shouty_field_name = AsShoutySnekCase(field_name)
155		)
156	}
157
158	fn is_sensitive(&self) -> bool {
159		self.sensitive.is_present()
160	}
161
162	fn is_optional(&self) -> bool {
163		#[allow(clippy::wildcard_enum_match_arm)] // Yes, that's rather the point here
164		match &self.ty {
165			Type::Path(tp) if tp.qself.is_none() => {
166				let path_idents = tp.path.segments.iter().fold(String::new(), |mut s, v| {
167					s.push_str(&v.ident.to_string());
168					s.push_str("->");
169					s
170				});
171				vec![
172					"Option->",
173					"std->option->Option->",
174					"core->option->Option->",
175				]
176				.into_iter()
177				.any(|s| *s == path_idents)
178			}
179			_ => false,
180		}
181	}
182
183	/// Determine the type of the field that we will want to parse into -- essentially,
184	/// the field's specified type less any wrapping Option<>, if present.
185	fn value_type(&self) -> &Type {
186		#[allow(clippy::wildcard_enum_match_arm)] // Yes, that's rather the point here
187		match &self.ty {
188			Type::Path(tp) if tp.qself.is_none() => {
189				if self.is_optional() {
190					#[allow(clippy::unwrap_used)] // There has to be segments if we got here
191					if let syn::PathArguments::AngleBracketed(args) =
192						&tp.path.segments.iter().next_back().unwrap().arguments
193					{
194						if let Some(syn::GenericArgument::Type(t)) = &args.args.iter().next() {
195							t
196						} else {
197							&self.ty
198						}
199					} else {
200						&self.ty
201					}
202				} else {
203					&self.ty
204				}
205			}
206			_ => &self.ty,
207		}
208	}
209
210	fn value_parser(&self) -> TokenStream {
211		if let Some(value_parser) = &self.value_parser {
212			// The as_ref() turns SpannedValue<T> into something that impls ToTokens, somehow
213			let parser = value_parser.as_ref();
214			quote_spanned! { value_parser.span()=> #parser }
215		} else {
216			let value_type = self.value_type();
217			quote_spanned! { self.ident.span()=>
218				|s: &str| s.parse::<#value_type>()
219			}
220		}
221	}
222
223	fn default_value(&self) -> TokenStream {
224		if let Some(default_value) = &self.default_value {
225			// The as_ref() turns SpannedValue<T> into something that impls ToTokens, somehow
226			let default_value = default_value.as_ref();
227			quote_spanned! { default_value.span()=> Some(#default_value) }
228		} else {
229			quote_spanned! { self.ident.span()=> None }
230		}
231	}
232}
233
234// Only used in integration tests
235#[cfg(test)]
236use trybuild as _;