// serde_reflect_intermediate_derive/lib.rs

1extern crate proc_macro;
2
3use proc_macro::{Span, TokenStream};
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident, Index, Lit, Meta, NestedMeta,
7};
8
9#[derive(Debug, Default)]
10struct TypeAttribs {
11    before_patch_change: Option<Ident>,
12    after_patch_change: Option<Ident>,
13}
14
15#[derive(Debug, Default)]
16struct FieldAttribs {
17    pub ignore: bool,
18    pub indirect: bool,
19}
20
21#[proc_macro_derive(ReflectIntermediate, attributes(reflect_intermediate))]
22pub fn derive_reflect_intermediate(input: TokenStream) -> TokenStream {
23    let ast = parse_macro_input!(input as DeriveInput);
24    let attribs = parse_type_attribs(&ast.attrs);
25    let before_patch_change = match attribs.before_patch_change {
26        Some(name) => {
27            quote! {
28                fn before_patch_change(&mut self) {
29                    self.#name();
30                }
31            }
32        }
33        None => Default::default(),
34    };
35    let after_patch_change = match attribs.after_patch_change {
36        Some(name) => {
37            quote! {
38                fn after_patch_change(&mut self) {
39                    self.#name();
40                }
41            }
42        }
43        None => Default::default(),
44    };
45    let name = &ast.ident;
46    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
47    match ast.data {
48        Data::Struct(data) => match data.fields {
49            Fields::Named(fields) => {
50                let fields = fields.named.iter().filter_map(|field| {
51                    let attribs = parse_field_attribs(&field.attrs);
52                    if attribs.ignore {
53                        return None;
54                    }
55                    let name = field.ident.as_ref().unwrap();
56                    let key = name.to_string();
57                    if attribs.indirect {
58                        Some(quote! {
59                            #key => {
60                                if let Ok(serialized) = serde_intermediate::to_intermediate(&self.#name) {
61                                    if let Ok(Some(patched)) = change.patch(&serialized) {
62                                        if let Ok(deserialized) = serde_intermediate::from_intermediate(&patched) {
63                                            self.#name = deserialized;
64                                        }
65                                    }
66                                }
67                            }
68                        })
69                    } else {
70                        Some(quote! {
71                            #key => {
72                                self.#name.patch_change(change);
73                            }
74                        })
75                    }
76                }).collect::<Vec<_>>();
77                quote! {
78                    impl #impl_generics serde_reflect_intermediate::ReflectIntermediate for #name #ty_generics #where_clause {
79                        fn patch_change(&mut self, change: &Change) {
80                            self.before_patch_change();
81                            match change {
82                                Change::Changed(v) => {
83                                    if let Ok(v) = serde_intermediate::from_intermediate(v) {
84                                        *self = v;
85                                    }
86                                }
87                                Change::PartialStruct(v) => {
88                                    for (name, change) in v {
89                                        match name.as_str() {
90                                            #( #fields )*
91                                            _ => {}
92                                        }
93                                    }
94                                }
95                                _ => {}
96                            }
97                            self.after_patch_change();
98                        }
99
100                        #before_patch_change
101
102                        #after_patch_change
103                    }
104                }.into()
105            }
106            Fields::Unnamed(fields) => {
107                let fields = fields.unnamed.iter().enumerate().filter_map(|(index,field)| {
108                    let attribs = parse_field_attribs(&field.attrs);
109                    if attribs.ignore {
110                        return None;
111                    }
112                    let tuple_index = Index::from(index);
113                    if attribs.indirect {
114                        Some(quote! {
115                            #index => {
116                                if let Ok(serialized) = serde_intermediate::to_intermediate(&self.#tuple_index) {
117                                    if let Ok(Some(patched)) = change.patch(&serialized) {
118                                        if let Ok(deserialized) = serde_intermediate::from_intermediate(&patched) {
119                                            self.#tuple_index = deserialized;
120                                        }
121                                    }
122                                }
123                            }
124                        })
125                    } else {
126                        Some(quote! {
127                            #index => {
128                                self.#tuple_index.patch_change(change);
129                            }
130                        })
131                    }
132                }).collect::<Vec<_>>();
133                quote! {
134                    impl #impl_generics serde_reflect_intermediate::ReflectIntermediate for #name #ty_generics #where_clause {
135                        fn patch_change(&mut self, change: &Change) {
136                            self.before_patch_change();
137                            match change {
138                                Change::Changed(v) => {
139                                    if let Ok(v) = serde_intermediate::from_intermediate(v) {
140                                        *self = v;
141                                    }
142                                }
143                                Change::PartialSeq(v) => {
144                                    for (index, change) in v {
145                                        match *index {
146                                            #( #fields )*
147                                            _ => {}
148                                        }
149                                    }
150                                }
151                                _ => {}
152                            }
153                            self.after_patch_change();
154                        }
155
156                        #before_patch_change
157
158                        #after_patch_change
159                    }
160                }.into()
161            }
162            Fields::Unit => quote! {
163                impl #impl_generics serde_reflect_intermediate::ReflectIntermediate for #name #ty_generics #where_clause {}
164            }
165            .into(),
166        },
167        Data::Enum(data) => {
168            let new_type_variants = data.variants.iter().filter_map(|variant| {
169                let name = &variant.ident;
170                if let Fields::Unnamed(_) = &variant.fields {
171                    if variant.fields.len() == 1 {
172                        let field = variant.fields.iter().next().unwrap();
173                        let attribs = parse_field_attribs(&field.attrs);
174                        if attribs.ignore {
175                            return None;
176                        }
177                        if attribs.indirect {
178                            Some(quote! {
179                                Self::#name(content) => {
180                                    if let Ok(serialized) = serde_intermediate::to_intermediate(content) {
181                                        if let Ok(Some(patched)) = change.patch(&serialized) {
182                                            if let Ok(deserialized) = serde_intermediate::from_intermediate(&patched) {
183                                                *content = deserialized;
184                                            }
185                                        }
186                                    }
187                                }
188                            })
189                        } else {
190                            Some(quote! {
191                                Self::#name(content) => {
192                                    content.patch_change(change);
193                                }
194                            })
195                        }
196                    } else {
197                        None
198                    }
199                } else {
200                    None
201                }
202            }).collect::<Vec<_>>();
203            let struct_variants = data.variants.iter().filter_map(|variant| {
204                let attribs = parse_field_attribs(&variant.attrs);
205                if attribs.ignore {
206                    return None;
207                }
208                let name = &variant.ident;
209                if let Fields::Named(fields) = &variant.fields {
210                    let field_names = fields
211                        .named
212                        .iter()
213                        .filter_map(|field| {
214                            let attribs = parse_field_attribs(&field.attrs);
215                            if attribs.ignore {
216                                return None;
217                            }
218                            Some(field.ident.as_ref().unwrap())
219                        })
220                        .collect::<Vec<_>>();
221                    let fields = fields.named.iter().filter_map(|field| {
222                        let attribs = parse_field_attribs(&field.attrs);
223                        if attribs.ignore {
224                            return None;
225                        }
226                        let name = field.ident.as_ref().unwrap();
227                        let key = name.to_string();
228                        if attribs.indirect {
229                            Some(quote! {
230                                #key => {
231                                    if let Ok(serialized) = serde_intermediate::to_intermediate(#name) {
232                                        if let Ok(Some(patched)) = change.patch(&serialized) {
233                                            if let Ok(deserialized) = serde_intermediate::from_intermediate(&patched) {
234                                                *#name = deserialized;
235                                            }
236                                        }
237                                    }
238                                }
239                            })
240                        } else {
241                            Some(quote! {
242                                #key => {
243                                    #name.patch_change(change);
244                                }
245                            })
246                        }
247                    }).collect::<Vec<_>>();
248                    Some(quote! {
249                        Self::#name { #( #field_names , )* .. } => {
250                            for (name, change) in v {
251                                match name.as_str() {
252                                    #( #fields )*
253                                    _ => {}
254                                }
255                            }
256                        }
257                    })
258                } else {
259                    None
260                }
261            }).collect::<Vec<_>>();
262            quote! {
263                impl #impl_generics serde_reflect_intermediate::ReflectIntermediate for #name #ty_generics #where_clause {
264                    fn patch_change(&mut self, change: &Change) {
265                        self.before_patch_change();
266                        match change {
267                            Change::Changed(v) => {
268                                if let Ok(v) = serde_intermediate::from_intermediate(v) {
269                                    *self = v;
270                                }
271                            }
272                            Change::PartialChange(change) => {
273                                match self {
274                                    #( #new_type_variants )*
275                                    _ => {}
276                                }
277                            }
278                            Change::PartialStruct(v) => {
279                                match self {
280                                    #( #struct_variants )*
281                                    _ => {}
282                                }
283                            }
284                            _ => {}
285                        }
286                        self.after_patch_change();
287                    }
288
289                    #before_patch_change
290
291                    #after_patch_change
292                }
293            }.into()
294        }
295        _ => panic!("ReflectIntermediate can be derived only for structs and enums"),
296    }
297}
298
299fn parse_type_attribs(attrs: &[Attribute]) -> TypeAttribs {
300    let mut result = TypeAttribs::default();
301    for attrib in attrs {
302        match attrib.parse_meta() {
303            Err(error) => panic!(
304                "Could not parse attribute `{}`: {:?}",
305                attrib.to_token_stream(),
306                error
307            ),
308            Ok(Meta::List(meta)) => {
309                if meta.path.is_ident("reflect_intermediate") {
310                    for meta in meta.nested {
311                        if let NestedMeta::Meta(Meta::NameValue(meta)) = &meta {
312                            if meta.path.is_ident("before_patch_change") {
313                                if let Lit::Str(value) = &meta.lit {
314                                    result.before_patch_change =
315                                        Some(Ident::new(&value.value(), Span::call_site().into()));
316                                }
317                            } else if meta.path.is_ident("after_patch_change") {
318                                if let Lit::Str(value) = &meta.lit {
319                                    result.after_patch_change =
320                                        Some(Ident::new(&value.value(), Span::call_site().into()));
321                                }
322                            }
323                        }
324                    }
325                }
326            }
327            _ => {}
328        }
329    }
330    result
331}
332
333fn parse_field_attribs(attrs: &[Attribute]) -> FieldAttribs {
334    let mut result = FieldAttribs::default();
335    for attrib in attrs {
336        match attrib.parse_meta() {
337            Err(error) => panic!(
338                "Could not parse attribute `{}`: {:?}",
339                attrib.to_token_stream(),
340                error
341            ),
342            Ok(Meta::List(meta)) => {
343                if meta.path.is_ident("reflect_intermediate") {
344                    for meta in meta.nested {
345                        if let NestedMeta::Meta(Meta::Path(path)) = &meta {
346                            if path.is_ident("ignore") {
347                                result.ignore = true;
348                            } else if path.is_ident("indirect") {
349                                result.indirect = true;
350                            }
351                        }
352                    }
353                }
354            }
355            _ => {}
356        }
357    }
358    result
359}