Skip to main content

sanitization_derive/
lib.rs

1#![deny(unsafe_code)]
2#![deny(unsafe_op_in_unsafe_fn)]
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{format_ident, quote};
7use syn::{
8    parse_macro_input, parse_quote, Attribute, Data, DataEnum, DataStruct, DeriveInput, Error,
9    Fields, Generics, LitStr, Path, Result, WherePredicate,
10};
11
12/// Derive `sanitization::SecureSanitize` for structs and enums.
13///
14/// Every non-skipped field must implement `SecureSanitize`. Use
15/// `#[sanitization(skip)]` only for fields that are intentionally non-secret or
16/// cleared elsewhere.
17///
18/// # Enums
19///
20/// For enums, generated code can only sanitize the currently active variant.
21/// It cannot safely reach bytes left behind by previously active variants after
22/// a variant transition. Use `sanitization::secure_replace` before replacement,
23/// derive `SecureSanitizeOnDrop` when drop-before-assignment semantics are
24/// wanted, or prefer struct wrappers for high-assurance state machines.
25///
26/// When the `strict-enum-derive` feature is enabled on this derive crate,
27/// enum derives require:
28///
29/// ```ignore
30/// #[sanitization(enum_inactive_variant_bytes = "acknowledged")]
31/// ```
32#[proc_macro_derive(SecureSanitize, attributes(sanitization))]
33pub fn derive_secure_sanitize(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as DeriveInput);
35    expand_secure_sanitize(&input)
36        .unwrap_or_else(Error::into_compile_error)
37        .into()
38}
39
40/// Derive `Drop` by calling `sanitization::SecureSanitize::secure_sanitize`.
41///
42/// # Generics
43///
44/// For structs with type parameters that hold sanitizable data, the parameter
45/// must carry the `SecureSanitize` bound at the type declaration:
46///
47/// ```ignore
48/// use sanitization::SecureSanitize;
49///
50/// #[derive(SecureSanitize, SecureSanitizeOnDrop)]
51/// struct Wrapper<T: SecureSanitize> {
52///     inner: T,
53/// }
54/// ```
55///
56/// This is a Rust `Drop` restriction: the generated `Drop` impl cannot add a
57/// stricter `T: SecureSanitize` bound than the struct declaration itself.
58#[proc_macro_derive(SecureSanitizeOnDrop, attributes(sanitization))]
59pub fn derive_secure_sanitize_on_drop(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as DeriveInput);
61    expand_secure_sanitize_on_drop(&input)
62        .unwrap_or_else(Error::into_compile_error)
63        .into()
64}
65
66/// Derive `sanitization::ct::ConstantTimeEq` for structs.
67///
68/// The generated implementation compares each non-skipped field through that
69/// field's own `ConstantTimeEq` implementation and combines the hidden
70/// `sanitization::ct::Choice` bits. It never compares raw struct bytes, so
71/// padding and representation details are not read.
72///
73/// Enums and unions are rejected. For enums, inactive variant bytes cannot be
74/// reached safely and comparing only the active variant can hide residual
75/// secret bytes from previous variants.
76#[proc_macro_derive(ConstantTimeEq, attributes(sanitization))]
77pub fn derive_constant_time_eq(input: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79    expand_constant_time_eq(&input)
80        .unwrap_or_else(Error::into_compile_error)
81        .into()
82}
83
84/// Derive `sanitization::ct::ConditionallySelectable` for structs.
85///
86/// The generated implementation selects every field through that field's own
87/// `ConditionallySelectable` implementation. `#[sanitization(skip)]` is
88/// intentionally rejected for this derive because the output must be a complete
89/// selection between `left` and `right`.
90///
91/// Enums and unions are rejected. Field-wise struct selection avoids raw
92/// representation reads and does not inspect padding bytes.
93#[proc_macro_derive(ConditionallySelectable, attributes(sanitization))]
94pub fn derive_conditionally_selectable(input: TokenStream) -> TokenStream {
95    let input = parse_macro_input!(input as DeriveInput);
96    expand_conditionally_selectable(&input)
97        .unwrap_or_else(Error::into_compile_error)
98        .into()
99}
100
101#[derive(Default)]
102struct ContainerOptions {
103    crate_path: Option<Path>,
104    bound_override: Option<Vec<WherePredicate>>,
105    enum_inactive_variant_bytes_acknowledged: bool,
106}
107
108#[derive(Default)]
109struct FieldOptions {
110    skip: bool,
111    bound_override: Option<Vec<WherePredicate>>,
112}
113
114fn expand_secure_sanitize(input: &DeriveInput) -> Result<TokenStream2> {
115    let options = parse_container_options(&input.attrs)?;
116    let crate_path = crate_path(&options);
117    let body = match &input.data {
118        Data::Struct(data) => expand_struct_body(data, &crate_path)?,
119        Data::Enum(data) => {
120            validate_enum_options(input, &options)?;
121            expand_enum_body(data, &crate_path)?
122        }
123        Data::Union(_) => {
124            return Err(Error::new_spanned(
125                input,
126                "SecureSanitize cannot be derived for unions; implement it manually using documented unsafe code for the active field",
127            ))
128        }
129    };
130    let generics = add_sanitize_bounds(input.generics.clone(), &input.data, &crate_path, &options)?;
131    let name = &input.ident;
132    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
133
134    Ok(quote! {
135        impl #impl_generics #crate_path::SecureSanitize for #name #type_generics #where_clause {
136            #[inline]
137            fn secure_sanitize(&mut self) {
138                #body
139            }
140        }
141    })
142}
143
144fn validate_enum_options(input: &DeriveInput, options: &ContainerOptions) -> Result<()> {
145    if cfg!(feature = "strict-enum-derive") && !options.enum_inactive_variant_bytes_acknowledged {
146        return Err(Error::new_spanned(
147            input,
148            "SecureSanitize enum derives are rejected by the strict-enum-derive feature unless #[sanitization(enum_inactive_variant_bytes = \"acknowledged\")] is present; derived enum sanitization only clears the active variant",
149        ));
150    }
151
152    Ok(())
153}
154
155fn expand_secure_sanitize_on_drop(input: &DeriveInput) -> Result<TokenStream2> {
156    let options = parse_container_options(&input.attrs)?;
157    let crate_path = crate_path(&options);
158
159    if matches!(input.data, Data::Union(_)) {
160        return Err(Error::new_spanned(
161            input,
162            "SecureSanitizeOnDrop cannot be derived for unions",
163        ));
164    }
165
166    let name = &input.ident;
167    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
168
169    Ok(quote! {
170        impl #impl_generics Drop for #name #type_generics #where_clause {
171            #[inline]
172            fn drop(&mut self) {
173                #crate_path::SecureSanitize::secure_sanitize(self);
174            }
175        }
176    })
177}
178
179fn expand_constant_time_eq(input: &DeriveInput) -> Result<TokenStream2> {
180    let options = parse_container_options(&input.attrs)?;
181    let crate_path = crate_path(&options);
182    let body = match &input.data {
183        Data::Struct(data) => expand_ct_eq_struct_body(data, &crate_path)?,
184        Data::Enum(_) => {
185            return Err(Error::new_spanned(
186                input,
187                "ConstantTimeEq cannot be derived for enums; compare explicit struct wrappers or implement the active-variant semantics manually",
188            ))
189        }
190        Data::Union(_) => {
191            return Err(Error::new_spanned(
192                input,
193                "ConstantTimeEq cannot be derived for unions; implement it manually using documented unsafe code for the active field",
194            ))
195        }
196    };
197    let trait_path: TokenStream2 = quote!(#crate_path::ct::ConstantTimeEq);
198    let generics = add_trait_bounds(
199        input.generics.clone(),
200        &input.data,
201        &trait_path,
202        &options,
203        SkipPolicy::Allow,
204    )?;
205    let name = &input.ident;
206    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
207
208    Ok(quote! {
209        impl #impl_generics #crate_path::ct::ConstantTimeEq for #name #type_generics #where_clause {
210            #[inline]
211            fn ct_eq(&self, other: &Self) -> #crate_path::ct::Choice {
212                #body
213            }
214        }
215    })
216}
217
218fn expand_conditionally_selectable(input: &DeriveInput) -> Result<TokenStream2> {
219    let options = parse_container_options(&input.attrs)?;
220    let crate_path = crate_path(&options);
221    let body = match &input.data {
222        Data::Struct(data) => expand_ct_select_struct_body(data, &crate_path)?,
223        Data::Enum(_) => {
224            return Err(Error::new_spanned(
225                input,
226                "ConditionallySelectable cannot be derived for enums; select explicit struct wrappers or implement the active-variant semantics manually",
227            ))
228        }
229        Data::Union(_) => {
230            return Err(Error::new_spanned(
231                input,
232                "ConditionallySelectable cannot be derived for unions; implement it manually using documented unsafe code for the active field",
233            ))
234        }
235    };
236    let trait_path: TokenStream2 = quote!(#crate_path::ct::ConditionallySelectable);
237    let generics = add_trait_bounds(
238        input.generics.clone(),
239        &input.data,
240        &trait_path,
241        &options,
242        SkipPolicy::Reject,
243    )?;
244    let name = &input.ident;
245    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
246
247    Ok(quote! {
248        impl #impl_generics #crate_path::ct::ConditionallySelectable for #name #type_generics #where_clause {
249            #[inline]
250            fn conditional_select(
251                left: &Self,
252                right: &Self,
253                choice: #crate_path::ct::Choice,
254            ) -> Self {
255                #body
256            }
257        }
258    })
259}
260
261fn crate_path(options: &ContainerOptions) -> Path {
262    options
263        .crate_path
264        .clone()
265        .unwrap_or_else(|| parse_quote!(::sanitization))
266}
267
268fn add_sanitize_bounds(
269    mut generics: Generics,
270    data: &Data,
271    crate_path: &Path,
272    options: &ContainerOptions,
273) -> Result<Generics> {
274    let where_clause = generics.make_where_clause();
275
276    if let Some(bounds) = &options.bound_override {
277        where_clause.predicates.extend(bounds.iter().cloned());
278        return Ok(generics);
279    }
280
281    for field in sanitized_fields(data)? {
282        let field_options = parse_field_options(&field.attrs)?;
283        if field_options.skip {
284            continue;
285        }
286
287        if let Some(bounds) = field_options.bound_override {
288            where_clause.predicates.extend(bounds);
289        } else {
290            let ty = &field.ty;
291            where_clause
292                .predicates
293                .push(parse_quote!(#ty: #crate_path::SecureSanitize));
294        }
295    }
296
297    Ok(generics)
298}
299
/// How `add_trait_bounds` treats fields marked `#[sanitization(skip)]`.
#[derive(Clone, Copy)]
enum SkipPolicy {
    /// Skipped fields are simply left out of the generated bounds.
    Allow,
    /// A skipped field is an error, for derives that must build every field.
    Reject,
}
305
306fn add_trait_bounds(
307    mut generics: Generics,
308    data: &Data,
309    trait_path: &TokenStream2,
310    options: &ContainerOptions,
311    skip_policy: SkipPolicy,
312) -> Result<Generics> {
313    let where_clause = generics.make_where_clause();
314
315    if let Some(bounds) = &options.bound_override {
316        where_clause.predicates.extend(bounds.iter().cloned());
317        return Ok(generics);
318    }
319
320    for field in sanitized_fields(data)? {
321        let field_options = parse_field_options(&field.attrs)?;
322        if field_options.skip {
323            if matches!(skip_policy, SkipPolicy::Reject) {
324                return Err(Error::new_spanned(
325                    field,
326                    "#[sanitization(skip)] is not supported for this derive because every output field must be constructed",
327                ));
328            }
329            continue;
330        }
331
332        if let Some(bounds) = field_options.bound_override {
333            where_clause.predicates.extend(bounds);
334        } else {
335            let ty = &field.ty;
336            where_clause.predicates.push(parse_quote!(#ty: #trait_path));
337        }
338    }
339
340    Ok(generics)
341}
342
343fn sanitized_fields(data: &Data) -> Result<Vec<&syn::Field>> {
344    let mut fields = Vec::new();
345    match data {
346        Data::Struct(data) => fields.extend(data.fields.iter()),
347        Data::Enum(data) => {
348            for variant in &data.variants {
349                fields.extend(variant.fields.iter());
350            }
351        }
352        Data::Union(_) => {}
353    }
354    Ok(fields)
355}
356
357fn expand_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
358    let calls = field_calls_for_struct(&data.fields, crate_path)?;
359    Ok(quote!(#(#calls)*))
360}
361
362fn expand_ct_eq_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
363    let mut calls = Vec::new();
364
365    for (index, field) in data.fields.iter().enumerate() {
366        if parse_field_options(&field.attrs)?.skip {
367            continue;
368        }
369
370        let (left, right) = match &field.ident {
371            Some(ident) => (quote!(&self.#ident), quote!(&other.#ident)),
372            None => {
373                let index = syn::Index::from(index);
374                (quote!(&self.#index), quote!(&other.#index))
375            }
376        };
377        calls.push(quote! {
378            result = result & #crate_path::ct::ConstantTimeEq::ct_eq(#left, #right);
379        });
380    }
381
382    Ok(quote! {
383        let mut result = #crate_path::ct::Choice::TRUE;
384        #(#calls)*
385        result
386    })
387}
388
389fn expand_ct_select_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
390    match &data.fields {
391        Fields::Named(fields) => {
392            let mut selected = Vec::new();
393            for field in &fields.named {
394                if parse_field_options(&field.attrs)?.skip {
395                    return Err(Error::new_spanned(
396                        field,
397                        "#[sanitization(skip)] is not supported for ConditionallySelectable derives",
398                    ));
399                }
400                let ident = field.ident.as_ref().expect("named field");
401                selected.push(quote! {
402                    #ident: #crate_path::ct::ConditionallySelectable::conditional_select(
403                        &left.#ident,
404                        &right.#ident,
405                        choice,
406                    )
407                });
408            }
409            Ok(quote!(Self { #(#selected),* }))
410        }
411        Fields::Unnamed(fields) => {
412            let mut selected = Vec::new();
413            for (index, field) in fields.unnamed.iter().enumerate() {
414                if parse_field_options(&field.attrs)?.skip {
415                    return Err(Error::new_spanned(
416                        field,
417                        "#[sanitization(skip)] is not supported for ConditionallySelectable derives",
418                    ));
419                }
420                let index = syn::Index::from(index);
421                selected.push(quote! {
422                    #crate_path::ct::ConditionallySelectable::conditional_select(
423                        &left.#index,
424                        &right.#index,
425                        choice,
426                    )
427                });
428            }
429            Ok(quote!(Self(#(#selected),*)))
430        }
431        Fields::Unit => Ok(quote!(Self)),
432    }
433}
434
435fn field_calls_for_struct(fields: &Fields, crate_path: &Path) -> Result<Vec<TokenStream2>> {
436    let mut calls = Vec::new();
437
438    for (index, field) in fields.iter().enumerate() {
439        if parse_field_options(&field.attrs)?.skip {
440            continue;
441        }
442
443        let access = match &field.ident {
444            Some(ident) => quote!(&mut self.#ident),
445            None => {
446                let index = syn::Index::from(index);
447                quote!(&mut self.#index)
448            }
449        };
450        calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#access);));
451    }
452
453    Ok(calls)
454}
455
456fn expand_enum_body(data: &DataEnum, crate_path: &Path) -> Result<TokenStream2> {
457    let mut arms = Vec::new();
458
459    for variant in &data.variants {
460        let variant_ident = &variant.ident;
461        let (pattern, calls) = match &variant.fields {
462            Fields::Named(fields) => {
463                let mut bindings = Vec::new();
464                let mut calls = Vec::new();
465                for field in &fields.named {
466                    let ident = field.ident.as_ref().expect("named field");
467                    if parse_field_options(&field.attrs)?.skip {
468                        continue;
469                    }
470                    bindings.push(quote!(#ident));
471                    calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#ident);));
472                }
473
474                let pattern = if bindings.is_empty() {
475                    quote!(Self::#variant_ident { .. })
476                } else {
477                    quote!(Self::#variant_ident { #(#bindings),*, .. })
478                };
479                (pattern, calls)
480            }
481            Fields::Unnamed(fields) => {
482                let mut pattern_fields = Vec::new();
483                let mut calls = Vec::new();
484                for (index, field) in fields.unnamed.iter().enumerate() {
485                    if parse_field_options(&field.attrs)?.skip {
486                        pattern_fields.push(quote!(_));
487                    } else {
488                        let binding = format_ident!("field_{index}");
489                        pattern_fields.push(quote!(#binding));
490                        calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#binding);));
491                    }
492                }
493                (quote!(Self::#variant_ident(#(#pattern_fields),*)), calls)
494            }
495            Fields::Unit => (quote!(Self::#variant_ident), Vec::new()),
496        };
497
498        arms.push(quote!(#pattern => { #(#calls)* }));
499    }
500
501    Ok(quote! {
502        match self {
503            #(#arms),*
504        }
505    })
506}
507
508fn parse_container_options(attrs: &[Attribute]) -> Result<ContainerOptions> {
509    let mut options = ContainerOptions::default();
510
511    for attr in attrs
512        .iter()
513        .filter(|attr| attr.path().is_ident("sanitization"))
514    {
515        attr.parse_nested_meta(|meta| {
516            if meta.path.is_ident("crate") {
517                let value = meta.value()?;
518                let literal: LitStr = value.parse()?;
519                options.crate_path = Some(literal.parse()?);
520                Ok(())
521            } else if meta.path.is_ident("bound") {
522                let value = meta.value()?;
523                let literal: LitStr = value.parse()?;
524                options.bound_override = Some(parse_bounds(&literal)?);
525                Ok(())
526            } else if meta.path.is_ident("enum_inactive_variant_bytes") {
527                let value = meta.value()?;
528                let literal: LitStr = value.parse()?;
529                if literal.value() == "acknowledged" {
530                    options.enum_inactive_variant_bytes_acknowledged = true;
531                    Ok(())
532                } else {
533                    Err(meta.error("enum_inactive_variant_bytes must be exactly \"acknowledged\""))
534                }
535            } else {
536                Err(meta.error("unsupported sanitization container attribute"))
537            }
538        })?;
539    }
540
541    Ok(options)
542}
543
544fn parse_field_options(attrs: &[Attribute]) -> Result<FieldOptions> {
545    let mut options = FieldOptions::default();
546
547    for attr in attrs
548        .iter()
549        .filter(|attr| attr.path().is_ident("sanitization"))
550    {
551        attr.parse_nested_meta(|meta| {
552            if meta.path.is_ident("skip") {
553                options.skip = true;
554                Ok(())
555            } else if meta.path.is_ident("bound") {
556                let value = meta.value()?;
557                let literal: LitStr = value.parse()?;
558                options.bound_override = Some(parse_bounds(&literal)?);
559                Ok(())
560            } else {
561                Err(meta.error("unsupported sanitization field attribute"))
562            }
563        })?;
564    }
565
566    Ok(options)
567}
568
569fn parse_bounds(literal: &LitStr) -> Result<Vec<WherePredicate>> {
570    let text = literal.value();
571    if text.trim().is_empty() {
572        return Ok(Vec::new());
573    }
574
575    let where_clause: syn::WhereClause = syn::parse_str(&format!("where {text}"))
576        .map_err(|error| Error::new(literal.span(), error))?;
577    Ok(where_clause.predicates.into_iter().collect())
578}