// secure_serialize_derive/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::Parser, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Expr, Fields, Lit,
5    Meta, Token,
6};
7
8/// Derives `SecureSerialize` for a struct.
9///
10/// Fields marked with `#[redact]` or `#[redact(with = "...")]` will be redacted when serialized.
11///
12/// # Field attributes
13///
14/// - `#[redact]` — Redact with default `"<redacted>"`
15/// - `#[redact(with = "***")]` — Redact with custom string `"***"`
16///
17/// # Struct attributes
18///
19/// - `#[secure_serialize(debug)]` — Generate `fmt::Debug` with redacted fields (declaration order).
20/// - `#[secure_serialize(display)]` — Generate `fmt::Display` as compact redacted JSON.
21/// - `#[secure_serialize(debug, display)]` — Both.
22///
23/// # Example
24///
25/// ```ignore
26/// #[derive(SecureSerialize, Deserialize)]
27/// #[secure_serialize(debug, display)]
28/// struct Config {
29///     pub host: String,
30///     #[redact]
31///     pub api_key: String,
32///     #[redact(with = "***")]
33///     pub password: String,
34/// }
35/// ```
36#[proc_macro_derive(SecureSerialize, attributes(redact, secure_serialize))]
37pub fn derive_secure_serialize(input: TokenStream) -> TokenStream {
38    let DeriveInput {
39        ident,
40        data,
41        generics,
42        attrs,
43        ..
44    } = parse_macro_input!(input);
45
46    let (gen_debug, gen_display) = match extract_secure_serialize_options(&attrs) {
47        Ok(v) => v,
48        Err(e) => return e.to_compile_error().into(),
49    };
50
51    let fields = match data {
52        Data::Struct(s) => match s.fields {
53            Fields::Named(f) => f.named,
54            _ => {
55                return syn::Error::new_spanned(
56                    &ident,
57                    "SecureSerialize only supports structs with named fields",
58                )
59                .to_compile_error()
60                .into();
61            }
62        },
63        _ => {
64            return syn::Error::new_spanned(&ident, "SecureSerialize only supports structs")
65                .to_compile_error()
66                .into();
67        }
68    };
69
70    // Separate fields into categories
71    let mut redacted_fields: Vec<(syn::Ident, String, String)> = Vec::new(); // (ident, name, redaction_string)
72    let mut redacted_custom_fields: Vec<(
73        syn::Ident,
74        String,
75        String,
76        proc_macro2::TokenStream,
77        syn::Type,
78    )> = Vec::new(); // (ident, name, redaction_string, serialize_path, type)
79    let mut custom_serialize_fields: Vec<(syn::Ident, proc_macro2::TokenStream, syn::Type)> =
80        Vec::new(); // (ident, serialize_path, type)
81    let mut normal_field_names: Vec<syn::Ident> = Vec::new();
82
83    for field in &fields {
84        let name = field.ident.as_ref().expect("named field");
85        let name_str = name.to_string();
86        let field_type = field.ty.clone();
87
88        // Check for #[redact] or #[redact(with = "...")]
89        let (is_redacted, redaction_string) = extract_redact_attribute(&field.attrs);
90
91        // Extract custom serialize_with function path
92        let custom_serialize_path = extract_serialize_with_attribute(&field.attrs);
93
94        match (is_redacted, &custom_serialize_path) {
95            (true, Some(path)) => redacted_custom_fields.push((
96                name.clone(),
97                name_str,
98                redaction_string,
99                path.clone(),
100                field_type,
101            )),
102            (true, None) => redacted_fields.push((name.clone(), name_str, redaction_string)),
103            (false, Some(path)) => {
104                custom_serialize_fields.push((name.clone(), path.clone(), field_type))
105            }
106            (false, None) => {
107                normal_field_names.push(name.clone());
108            }
109        }
110    }
111
112    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
113
114    // Extract names and idents for code generation
115    let redacted_field_names: Vec<String> =
116        redacted_fields.iter().map(|(_, n, _)| n.clone()).collect();
117    let redacted_field_idents: Vec<syn::Ident> =
118        redacted_fields.iter().map(|(i, _, _)| i.clone()).collect();
119    let redaction_strings: Vec<proc_macro2::TokenStream> = redacted_fields
120        .iter()
121        .map(|(_, _, r)| {
122            r.parse::<proc_macro2::TokenStream>()
123                .unwrap_or_else(|_| quote! { "<redacted>" })
124        })
125        .collect();
126
127    let redacted_custom_field_names: Vec<String> = redacted_custom_fields
128        .iter()
129        .map(|(_, n, _, _, _)| n.clone())
130        .collect();
131    let redacted_custom_strings: Vec<proc_macro2::TokenStream> = redacted_custom_fields
132        .iter()
133        .map(|(_, _, r, _, _)| {
134            r.parse::<proc_macro2::TokenStream>()
135                .unwrap_or_else(|_| quote! { "<redacted>" })
136        })
137        .collect();
138
139    let custom_serialize_idents: Vec<syn::Ident> = custom_serialize_fields
140        .iter()
141        .map(|(i, _, _)| i.clone())
142        .collect();
143
144    // Helper to generate wrapper for custom serialize_with
145    let generate_wrapper = |field_ident: &syn::Ident,
146                            path: &proc_macro2::TokenStream,
147                            field_type: &syn::Type| {
148        quote! {
149            {
150                struct _Wrapper<'a>(&'a #field_type);
151                impl<'a> ::serde::Serialize for _Wrapper<'a>
152                {
153                    fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
154                    where
155                        S: ::serde::Serializer,
156                    {
157                        #path(self.0, serializer)
158                    }
159                }
160                _Wrapper(&self.#field_ident)
161            }
162        }
163    };
164
165    // Generate wrappers for custom serialize fields in impl Serialize
166    let custom_field_wrappers: Vec<proc_macro2::TokenStream> = custom_serialize_fields
167        .iter()
168        .map(|(ident, path, ty)| generate_wrapper(ident, path, ty))
169        .collect();
170
171    // Generate wrappers for redacted_custom fields in to_json_unredacted
172    let redacted_custom_json_wrappers: Vec<proc_macro2::TokenStream> = redacted_custom_fields
173        .iter()
174        .map(|(ident, _, _, path, ty)| {
175            let wrapper = generate_wrapper(ident, path, ty);
176            quote! { ::serde_json::to_value(#wrapper)? }
177        })
178        .collect();
179
180    // Generate wrappers for custom_serialize fields in to_json_unredacted
181    let custom_json_wrappers: Vec<proc_macro2::TokenStream> = custom_serialize_fields
182        .iter()
183        .map(|(ident, path, ty)| {
184            let wrapper = generate_wrapper(ident, path, ty);
185            quote! { ::serde_json::to_value(#wrapper)? }
186        })
187        .collect();
188
189    let debug_field_fragments: Vec<proc_macro2::TokenStream> = fields
190        .iter()
191        .map(|field| {
192            let name = field.ident.as_ref().expect("named field");
193            let name_literal = name.to_string();
194            let (is_redacted, redaction_string) = extract_redact_attribute(&field.attrs);
195            if is_redacted {
196                let redact_ts = redaction_string
197                    .parse::<proc_macro2::TokenStream>()
198                    .unwrap_or_else(|_| quote! { "<redacted>" });
199                quote! {
200                    .field(#name_literal, &#redact_ts)
201                }
202            } else {
203                quote! {
204                    .field(#name_literal, &self.#name)
205                }
206            }
207        })
208        .collect();
209
210    let debug_impl = if gen_debug {
211        quote! {
212            impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
213                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
214                    f.debug_struct(stringify!(#ident))
215                        #(#debug_field_fragments)*
216                        .finish()
217                }
218            }
219        }
220    } else {
221        quote! {}
222    };
223
224    let display_impl = if gen_display {
225        quote! {
226            impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
227                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
228                    match ::serde_json::to_string(self) {
229                        Ok(ref json) => f.write_str(json),
230                        Err(e) => ::std::write!(
231                            f,
232                            concat!(stringify!(#ident), "(serialization error: {})"),
233                            e
234                        ),
235                    }
236                }
237            }
238        }
239    } else {
240        quote! {}
241    };
242
243    let expanded = quote! {
244        impl #impl_generics ::serde::Serialize for #ident #ty_generics #where_clause {
245            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
246            where
247                S: ::serde::Serializer,
248            {
249                use ::serde::ser::SerializeStruct;
250                let mut s = serializer.serialize_struct(
251                    stringify!(#ident),
252                    0usize
253                    #(+ { let _ = stringify!(#redacted_field_names); 1usize })*
254                    #(+ { let _ = stringify!(#redacted_custom_field_names); 1usize })*
255                    #(+ { let _ = stringify!(#custom_serialize_idents); 1usize })*
256                    #(+ { let _ = stringify!(#normal_field_names); 1usize })*
257                )?;
258
259                // Serialize redacted fields with their redaction strings
260                #(s.serialize_field(#redacted_field_names, #redaction_strings)?;)*
261
262                // Serialize redacted fields with custom serialize using their redaction strings
263                #(s.serialize_field(#redacted_custom_field_names, #redacted_custom_strings)?;)*
264
265                // Serialize non-secret fields with custom serializers
266                #(s.serialize_field(stringify!(#custom_serialize_idents), &#custom_field_wrappers)?;)*
267
268                // Serialize normal fields directly
269                #(s.serialize_field(stringify!(#normal_field_names), &self.#normal_field_names)?;)*
270
271                s.end()
272            }
273        }
274
275        impl #impl_generics ::secure_serialize::SecureSerialize for #ident #ty_generics #where_clause {
276            fn redacted_keys() -> &'static [&'static str] {
277                &[#(#redacted_field_names,)* #(#redacted_custom_field_names,)*]
278            }
279
280            fn to_json_unredacted(&self) -> ::std::result::Result<::serde_json::Value, ::serde_json::Error> {
281                use ::serde_json::Value as JsonValue;
282                let mut result = ::serde_json::Map::new();
283
284                // Redacted fields - use to_value for proper serialization
285                #(result.insert(#redacted_field_names.to_string(), ::serde_json::to_value(&self.#redacted_field_idents)?);)*
286
287                // Redacted fields with custom serialize - use custom serializer
288                #(result.insert(#redacted_custom_field_names.to_string(), #redacted_custom_json_wrappers);)*
289
290                // Custom serialize fields (non-redacted) - use custom serializer
291                #(result.insert(stringify!(#custom_serialize_idents).to_string(), #custom_json_wrappers);)*
292
293                // Normal fields
294                #(result.insert(stringify!(#normal_field_names).to_string(), ::serde_json::to_value(&self.#normal_field_names)?);)*
295
296                Ok(JsonValue::Object(result))
297            }
298        }
299
300        #debug_impl
301        #display_impl
302    };
303
304    let tokens = expanded.into();
305    // eprintln!("GENERATED TOKENS:\n{}", tokens);
306    tokens
307}
308
309/// Parses `#[secure_serialize(debug)]`, `#[secure_serialize(display)]`, or both on the struct.
310fn extract_secure_serialize_options(attrs: &[syn::Attribute]) -> Result<(bool, bool), syn::Error> {
311    let mut gen_debug = false;
312    let mut gen_display = false;
313
314    for attr in attrs {
315        if !attr.path().is_ident("secure_serialize") {
316            continue;
317        }
318
319        match &attr.meta {
320            Meta::Path(_) => {
321                return Err(syn::Error::new_spanned(
322                    attr,
323                    "expected #[secure_serialize(debug)], #[secure_serialize(display)], or both",
324                ));
325            }
326            Meta::List(list) => {
327                if list.tokens.is_empty() {
328                    return Err(syn::Error::new_spanned(
329                        list,
330                        "expected `debug` and/or `display` inside #[secure_serialize(...)]",
331                    ));
332                }
333                let metas = Punctuated::<Meta, Token![,]>::parse_terminated
334                    .parse2(list.tokens.clone())?;
335                for meta in metas {
336                    match meta {
337                        Meta::Path(p) => {
338                            if p.is_ident("debug") {
339                                gen_debug = true;
340                            } else if p.is_ident("display") {
341                                gen_display = true;
342                            } else {
343                                return Err(syn::Error::new_spanned(
344                                    p,
345                                    "expected `debug` or `display`",
346                                ));
347                            }
348                        }
349                        other => {
350                            return Err(syn::Error::new_spanned(
351                                other,
352                                "expected `debug` or `display`",
353                            ));
354                        }
355                    }
356                }
357            }
358            Meta::NameValue(_) => {
359                return Err(syn::Error::new_spanned(
360                    attr,
361                    "invalid #[secure_serialize(...)] syntax",
362                ));
363            }
364        }
365    }
366
367    Ok((gen_debug, gen_display))
368}
369
370/// Extracts the `#[redact]` or `#[redact(with = "...")]` attribute from a field.
371/// Returns `(true, redaction_string)` if found, `(false, _)` otherwise.
372fn extract_redact_attribute(attrs: &[syn::Attribute]) -> (bool, String) {
373    for attr in attrs {
374        if !attr.path().is_ident("redact") {
375            continue;
376        }
377
378        match &attr.meta {
379            syn::Meta::Path(_) => {
380                // #[redact] with no arguments
381                return (true, "\"<redacted>\"".to_string());
382            }
383            syn::Meta::List(list) => {
384                // #[redact(...)]
385                if let Ok(Meta::NameValue(nv)) = list.parse_args::<Meta>().and_then(|m| match m {
386                    Meta::NameValue(nv) if nv.path.is_ident("with") => Ok(Meta::NameValue(nv)),
387                    _ => Err(syn::Error::new_spanned(
388                        &list,
389                        "redact attribute expects: #[redact(with = \"string\")]",
390                    )),
391                }) {
392                    if let syn::Expr::Lit(expr_lit) = &nv.value {
393                        if let syn::Lit::Str(lit_str) = &expr_lit.lit {
394                            // Return the string literal with quotes preserved
395                            return (true, format!("\"{}\"", lit_str.value()));
396                        }
397                    }
398                }
399            }
400            _ => {}
401        }
402    }
403
404    (false, String::new())
405}
406
407/// Extracts the `serialize_with` path from `#[serde(...)]` attributes.
408fn extract_serialize_with_attribute(attrs: &[syn::Attribute]) -> Option<proc_macro2::TokenStream> {
409    for attr in attrs {
410        if !attr.path().is_ident("serde") {
411            continue;
412        }
413
414        let Meta::List(list) = &attr.meta else {
415            continue;
416        };
417
418        let metas = Punctuated::<Meta, Token![,]>::parse_terminated
419            .parse2(list.tokens.clone())
420            .ok()?;
421
422        for meta in metas {
423            let Meta::NameValue(name_value) = meta else {
424                continue;
425            };
426            if !name_value.path.is_ident("serialize_with") {
427                continue;
428            }
429
430            let Expr::Lit(expr_lit) = &name_value.value else {
431                continue;
432            };
433            let Lit::Str(value) = &expr_lit.lit else {
434                continue;
435            };
436
437            return value.parse::<proc_macro2::TokenStream>().ok();
438        }
439    }
440
441    None
442}