Skip to main content

serde_extensions_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, Fields};
4
5#[proc_macro_derive(Overwrite)]
6pub fn optional_derive(input: TokenStream) -> TokenStream {
7    // Construct a representation of Rust code as a syntax tree
8    // that we can manipulate
9    let ast = syn::parse(input).unwrap();
10
11    // Build the trait implementation
12    impl_optional_macro(&ast)
13}
14
15fn impl_optional_macro(ast: &syn::DeriveInput) -> TokenStream {
16    let name = &ast.ident;
17    let optional_name = format_ident!("Optional{}", name);
18
19    let fields = match &ast.data {
20        Data::Struct(data) => match &data.fields {
21            Fields::Named(fields) => &fields.named,
22            _ => panic!("Only named fields are supported"),
23        },
24        _ => panic!("Only structs are supported"),
25    };
26
27    let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
28
29    let overwrite_code: Vec<_> = fields
30        .iter()
31        .map(|field| {
32            let field_name = &field.ident;
33            quote! {
34                if let Some(value) = optional.#field_name {
35                    // Helper trait to try overwriting if Overwrite is implemented
36                    trait MaybeOverwrite {
37                        fn maybe_overwrite<E: serde::de::Error>(&mut self, value: ::serde_extensions::serde_value::Value) -> Result<(), E>;
38                    }
39                    
40                    // Implementation for types that implement Overwrite
41                    impl<T: Overwrite> MaybeOverwrite for T {
42                        fn maybe_overwrite<E: serde::de::Error>(&mut self, value: ::serde_extensions::serde_value::Value) -> Result<(), E> {
43                            let result: Result<(), ::serde_extensions::serde_value::DeserializerError> = 
44                                self.overwrite(::serde_extensions::serde_value::ValueDeserializer::new(value));
45                            result.map_err(|e| E::custom(e.to_string()))?;
46                            Ok(())
47                        }
48                    }
49                    
50                    self.#field_name.maybe_overwrite::<D::Error>(value)?;
51                }
52            }
53        })
54        .collect();
55
56    let gen = quote! {
57        // make an optional version of this struct that stores raw serde values
58        #[derive(serde::Deserialize, Default)]
59        #[serde(default)]
60        struct #optional_name {
61            #( #field_names: Option<::serde_extensions::serde_value::Value>, )*
62        }
63
64        // parse `d` to optional struct and check for each field if it has a value
65        // if yes then overwrite `$field_name` of `#name`
66        impl Overwrite for #name {
67            /// Overwrite self with a serde object
68            fn overwrite<'de, D>(&mut self, d: D) -> Result<(), D::Error>
69            where
70                D: serde::Deserializer<'de>
71            {
72                let optional: #optional_name = serde::de::Deserialize::deserialize(d)?;
73                #( #overwrite_code )*
74                return Ok(())
75            }
76        }
77    };
78    gen.into()
79}