struct_patch_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro2::{Ident, Span, TokenStream};
3use quote::{quote, ToTokens};
4use std::str::FromStr;
5use syn::{
6    meta::ParseNestedMeta, parenthesized, spanned::Spanned, DeriveInput, Error, LitStr, Result,
7    Type,
8};
9
// Keys recognized inside `#[patch(...)]` helper attributes; attribute paths are stringified
// and matched against these.
/// Name of the helper attribute itself: `#[patch(...)]`.
const PATCH: &str = "patch";
/// `name = "..."`: on the container it names the generated patch struct; on a field it
/// replaces the field's type (the field is then treated as a nested patch).
const NAME: &str = "name";
/// `attribute(...)`: the inner tokens are re-emitted as `#[...]` on the generated struct/field.
const ATTRIBUTE: &str = "attribute";
/// `skip`: the field is left out of the patch struct.
const SKIP: &str = "skip";
/// `addable`: when two patches both set the field, `+` combines them (needs the `op` feature).
const ADDABLE: &str = "addable";
/// `add = function`: when two patches both set the field, `function(a, b)` combines them
/// (needs the `op` feature).
const ADD: &str = "add";
16
17#[proc_macro_derive(Patch, attributes(patch))]
18pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
19    Patch::from_ast(syn::parse_macro_input!(item as syn::DeriveInput))
20        .unwrap()
21        .to_token_stream()
22        .unwrap()
23        .into()
24}
25
26struct Patch {
27    visibility: syn::Visibility,
28    struct_name: Ident,
29    patch_struct_name: Ident,
30    generics: syn::Generics,
31    attributes: Vec<TokenStream>,
32    fields: Vec<Field>,
33}
34
/// How the generated `Add` impl combines a field when both patches set it.
///
/// The whole enum only exists with the `op` feature, so its variants need no extra gating.
#[cfg(feature = "op")]
enum Addable {
    /// Default: adding two patches that both set this field panics.
    Disable,
    /// `#[patch(addable)]`: the two values are combined with `a + b`.
    AddTriat,
    /// `#[patch(add = func)]`: the two values are combined with `func(a, b)`.
    AddFn(Ident),
}
42
43struct Field {
44    ident: Option<Ident>,
45    ty: Type,
46    attributes: Vec<TokenStream>,
47    retyped: bool,
48    #[cfg(feature = "op")]
49    addable: Addable,
50}
51
52impl Patch {
53    /// Generate the token stream for the patch struct and it resulting implementations
54    pub fn to_token_stream(&self) -> Result<TokenStream> {
55        let Patch {
56            visibility,
57            struct_name,
58            patch_struct_name: name,
59            generics,
60            attributes,
61            fields,
62        } = self;
63
64        let patch_struct_fields = fields
65            .iter()
66            .map(|f| f.to_token_stream())
67            .collect::<Result<Vec<_>>>()?;
68        let field_names = fields.iter().map(|f| f.ident.as_ref()).collect::<Vec<_>>();
69
70        let renamed_field_names = fields
71            .iter()
72            .filter(|f| f.retyped)
73            .map(|f| f.ident.as_ref())
74            .collect::<Vec<_>>();
75
76        let original_field_names = fields
77            .iter()
78            .filter(|f| !f.retyped)
79            .map(|f| f.ident.as_ref())
80            .collect::<Vec<_>>();
81
82        let mapped_attributes = attributes
83            .iter()
84            .map(|a| {
85                quote! {
86                    #[#a]
87                }
88            })
89            .collect::<Vec<_>>();
90
91        let patch_struct = quote! {
92            #(#mapped_attributes)*
93            #visibility struct #name #generics {
94                #(#patch_struct_fields)*
95            }
96        };
97        let where_clause = &generics.where_clause;
98
99        #[cfg(feature = "status")]
100        let patch_status_impl = quote!(
101            impl #generics struct_patch::traits::PatchStatus for #name #generics #where_clause {
102                fn is_empty(&self) -> bool {
103                    #(
104                        if self.#field_names.is_some() {
105                            return false
106                        }
107                    )*
108                    true
109                }
110            }
111        );
112        #[cfg(not(feature = "status"))]
113        let patch_status_impl = quote!();
114
115        #[cfg(feature = "merge")]
116        let patch_merge_impl = quote!(
117            impl #generics struct_patch::traits::Merge for #name #generics #where_clause {
118                fn merge(self, other: Self) -> Self {
119                    Self {
120                        #(
121                            #renamed_field_names: match (self.#renamed_field_names, other.#renamed_field_names) {
122                                (Some(a), Some(b)) => Some(a.merge(b)),
123                                (Some(a), None) => Some(a),
124                                (None, Some(b)) => Some(b),
125                                (None, None) => None,
126                            },
127                        )*
128                        #(
129                            #original_field_names: other.#original_field_names.or(self.#original_field_names),
130                        )*
131                    }
132                }
133            }
134        );
135        #[cfg(not(feature = "merge"))]
136        let patch_merge_impl = quote!();
137
138        #[cfg(feature = "op")]
139        let addable_handles = fields
140            .iter()
141            .map(|f| {
142                match &f.addable {
143                    Addable::AddTriat => quote!(
144                        Some(a + b)
145                    ),
146                    Addable::AddFn(f) => {
147                        quote!(
148                            Some(#f(a, b))
149                        )
150                    } ,
151                    Addable::Disable => quote!(
152                        panic!("There are conflict patches, please use `#[patch(addable)]` if you want to add these values.")
153                    )
154                }
155            })
156            .collect::<Vec<_>>();
157
158        #[cfg(all(feature = "op", not(feature = "merge")))]
159        let op_impl = quote! {
160            impl #generics core::ops::Shl<#name #generics> for #struct_name #generics #where_clause {
161                type Output = Self;
162
163                fn shl(mut self, rhs: #name #generics) -> Self {
164                    struct_patch::traits::Patch::apply(&mut self, rhs);
165                    self
166                }
167            }
168
169            impl #generics core::ops::Add<Self> for #name #generics #where_clause {
170                type Output = Self;
171
172                fn add(mut self, rhs: Self) -> Self {
173                    Self {
174                        #(
175                            #renamed_field_names: match (self.#renamed_field_names, rhs.#renamed_field_names) {
176                                (Some(a), Some(b)) => {
177                                    #addable_handles
178                                },
179                                (Some(a), None) => Some(a),
180                                (None, Some(b)) => Some(b),
181                                (None, None) => None,
182                            },
183                        )*
184                        #(
185                            #original_field_names: match (self.#original_field_names, rhs.#original_field_names) {
186                                (Some(a), Some(b)) => {
187                                    #addable_handles
188                                },
189                                (Some(a), None) => Some(a),
190                                (None, Some(b)) => Some(b),
191                                (None, None) => None,
192                            },
193                        )*
194                    }
195                }
196            }
197        };
198
199        #[cfg(feature = "merge")]
200        let op_impl = quote! {
201            impl #generics core::ops::Shl<#name #generics> for #struct_name #generics #where_clause {
202                type Output = Self;
203
204                fn shl(mut self, rhs: #name #generics) -> Self {
205                    struct_patch::traits::Patch::apply(&mut self, rhs);
206                    self
207                }
208            }
209
210            impl #generics core::ops::Shl<#name #generics> for #name #generics #where_clause {
211                type Output = Self;
212
213                fn shl(mut self, rhs: Self) -> Self {
214                    struct_patch::traits::Merge::merge(self, rhs)
215                }
216            }
217
218            impl #generics core::ops::Add<Self> for #name #generics #where_clause {
219                type Output = Self;
220
221                fn add(mut self, rhs: Self) -> Self {
222                    Self {
223                        #(
224                            #renamed_field_names: match (self.#renamed_field_names, rhs.#renamed_field_names) {
225                                (Some(a), Some(b)) => {
226                                    #addable_handles
227                                },
228                                (Some(a), None) => Some(a),
229                                (None, Some(b)) => Some(b),
230                                (None, None) => None,
231                            },
232                        )*
233                        #(
234                            #original_field_names: match (self.#original_field_names, rhs.#original_field_names) {
235                                (Some(a), Some(b)) => {
236                                    #addable_handles
237                                },
238                                (Some(a), None) => Some(a),
239                                (None, Some(b)) => Some(b),
240                                (None, None) => None,
241                            },
242                        )*
243                    }
244                }
245            }
246        };
247
248        #[cfg(not(feature = "op"))]
249        let op_impl = quote!();
250
251        let patch_impl = quote! {
252            impl #generics struct_patch::traits::Patch< #name #generics > for #struct_name #generics #where_clause  {
253                fn apply(&mut self, patch: #name #generics) {
254                    #(
255                        if let Some(v) = patch.#renamed_field_names {
256                            self.#renamed_field_names.apply(v);
257                        }
258                    )*
259                    #(
260                        if let Some(v) = patch.#original_field_names {
261                            self.#original_field_names = v;
262                        }
263                    )*
264                }
265
266                fn into_patch(self) -> #name #generics {
267                    #name {
268                        #(
269                            #renamed_field_names: Some(self.#renamed_field_names.into_patch()),
270                        )*
271                        #(
272                            #original_field_names: Some(self.#original_field_names),
273                        )*
274                    }
275                }
276
277                fn into_patch_by_diff(self, previous_struct: Self) -> #name #generics {
278                    #name {
279                        #(
280                            #renamed_field_names: if self.#renamed_field_names != previous_struct.#renamed_field_names {
281                                Some(self.#renamed_field_names.into_patch_by_diff(previous_struct.#renamed_field_names))
282                            }
283                            else {
284                                None
285                            },
286                        )*
287                        #(
288                            #original_field_names: if self.#original_field_names != previous_struct.#original_field_names {
289                                Some(self.#original_field_names)
290                            }
291                            else {
292                                None
293                            },
294                        )*
295                    }
296                }
297
298                fn new_empty_patch() -> #name #generics {
299                    #name {
300                        #(
301                            #field_names: None,
302                        )*
303                    }
304                }
305            }
306        };
307
308        Ok(quote! {
309            #patch_struct
310
311            #patch_status_impl
312
313            #patch_merge_impl
314
315            #patch_impl
316
317            #op_impl
318        })
319    }
320
321    /// Parse the patch struct
322    pub fn from_ast(
323        DeriveInput {
324            ident,
325            data,
326            generics,
327            attrs,
328            vis,
329        }: syn::DeriveInput,
330    ) -> Result<Patch> {
331        let original_fields = if let syn::Data::Struct(syn::DataStruct { fields, .. }) = data {
332            fields
333        } else {
334            return Err(syn::Error::new(
335                ident.span(),
336                "Patch derive only use for struct",
337            ));
338        };
339
340        let mut name = None;
341        let mut attributes = vec![];
342        let mut fields = vec![];
343
344        for attr in attrs {
345            if attr.path().to_string().as_str() != PATCH {
346                continue;
347            }
348
349            if let syn::Meta::List(meta) = &attr.meta {
350                if meta.tokens.is_empty() {
351                    continue;
352                }
353            }
354
355            attr.parse_nested_meta(|meta| {
356                let path = meta.path.to_string();
357                match path.as_str() {
358                    NAME => {
359                        // #[patch(name = "PatchStruct")]
360                        if let Some(lit) = get_lit_str(path, &meta)? {
361                            if name.is_some() {
362                                return Err(meta
363                                    .error("The name attribute can't be defined more than once"));
364                            }
365                            name = Some(lit.parse()?);
366                        }
367                    }
368                    ATTRIBUTE => {
369                        // #[patch(attribute(derive(Deserialize)))]
370                        // #[patch(attribute(derive(Deserialize, Debug), serde(rename = "foo"))]
371                        let content;
372                        parenthesized!(content in meta.input);
373                        let attribute: TokenStream = content.parse()?;
374                        attributes.push(attribute);
375                    }
376                    _ => {
377                        return Err(meta.error(format_args!(
378                            "unknown patch container attribute `{}`",
379                            path.replace(' ', "")
380                        )));
381                    }
382                }
383                Ok(())
384            })?;
385        }
386
387        for field in original_fields {
388            if let Some(f) = Field::from_ast(field)? {
389                fields.push(f);
390            }
391        }
392
393        Ok(Patch {
394            visibility: vis,
395            patch_struct_name: name.unwrap_or({
396                let ts = TokenStream::from_str(&format!("{}Patch", &ident,)).unwrap();
397                let lit = LitStr::new(&ts.to_string(), Span::call_site());
398                lit.parse()?
399            }),
400            struct_name: ident,
401            generics,
402            attributes,
403            fields,
404        })
405    }
406}
407
408impl Field {
409    /// Generate the token stream for the Patch struct fields
410    pub fn to_token_stream(&self) -> Result<TokenStream> {
411        let Field {
412            ident,
413            ty,
414            attributes,
415            ..
416        } = self;
417
418        let attributes = attributes
419            .iter()
420            .map(|a| {
421                quote! {
422                    #[#a]
423                }
424            })
425            .collect::<Vec<_>>();
426        match ident {
427            Some(ident) => Ok(quote! {
428                #(#attributes)*
429                pub #ident: Option<#ty>,
430            }),
431            None => Ok(quote! {
432                #(#attributes)*
433                pub Option<#ty>,
434            }),
435        }
436    }
437
438    /// Parse the patch struct field
439    pub fn from_ast(
440        syn::Field {
441            ident, ty, attrs, ..
442        }: syn::Field,
443    ) -> Result<Option<Field>> {
444        let mut attributes = vec![];
445        let mut field_type = None;
446        let mut skip = false;
447
448        #[cfg(feature = "op")]
449        let mut addable = Addable::Disable;
450
451        for attr in attrs {
452            if attr.path().to_string().as_str() != PATCH {
453                continue;
454            }
455
456            if let syn::Meta::List(meta) = &attr.meta {
457                if meta.tokens.is_empty() {
458                    continue;
459                }
460            }
461
462            attr.parse_nested_meta(|meta| {
463                let path = meta.path.to_string();
464                match path.as_str() {
465                    SKIP => {
466                        // #[patch(skip)]
467                        skip = true;
468                    }
469                    ATTRIBUTE => {
470                        // #[patch(attribute(serde(alias = "my-field")))]
471                        let content;
472                        parenthesized!(content in meta.input);
473                        let attribute: TokenStream = content.parse()?;
474                        attributes.push(attribute);
475                    }
476                    NAME => {
477                        // #[patch(name = "ItemPatch")]
478                        let expr: LitStr = meta.value()?.parse()?;
479                        field_type = Some(expr.parse()?)
480                    }
481                    #[cfg(feature = "op")]
482                    ADDABLE => {
483                        // #[patch(addable)]
484                        addable = Addable::AddTriat;
485                    }
486                    #[cfg(not(feature = "op"))]
487                    ADDABLE => {
488                        return Err(syn::Error::new(
489                            ident.span(),
490                            "`addable` needs `op` feature",
491                        ));
492                    }
493                    #[cfg(feature = "op")]
494                    ADD => {
495                        // #[patch(add=fn)]
496                        let f: Ident = meta.value()?.parse()?;
497                        addable = Addable::AddFn(f);
498                    }
499                    #[cfg(not(feature = "op"))]
500                    ADD => {
501                        return Err(syn::Error::new(ident.span(), "`add` needs `op` feature"));
502                    }
503                    _ => {
504                        return Err(meta.error(format_args!(
505                            "unknown patch field attribute `{}`",
506                            path.replace(' ', "")
507                        )));
508                    }
509                }
510                Ok(())
511            })?;
512            if skip {
513                return Ok(None);
514            }
515        }
516
517        Ok(Some(Field {
518            ident,
519            retyped: field_type.is_some(),
520            ty: field_type.unwrap_or(ty),
521            attributes,
522            #[cfg(feature = "op")]
523            addable,
524        }))
525    }
526}
527
528trait ToStr {
529    fn to_string(&self) -> String;
530}
531
532impl ToStr for syn::Path {
533    fn to_string(&self) -> String {
534        self.to_token_stream().to_string()
535    }
536}
537
538fn get_lit_str(attr_name: String, meta: &ParseNestedMeta) -> syn::Result<Option<syn::LitStr>> {
539    let expr: syn::Expr = meta.value()?.parse()?;
540    let mut value = &expr;
541    while let syn::Expr::Group(e) = value {
542        value = &e.expr;
543    }
544    if let syn::Expr::Lit(syn::ExprLit {
545        lit: syn::Lit::Str(lit),
546        ..
547    }) = value
548    {
549        let suffix = lit.suffix();
550        if !suffix.is_empty() {
551            return Err(Error::new(
552                lit.span(),
553                format!("unexpected suffix `{}` on string literal", suffix),
554            ));
555        }
556        Ok(Some(lit.clone()))
557    } else {
558        Err(Error::new(
559            expr.span(),
560            format!(
561                "expected serde {} attribute to be a string: `{} = \"...\"`",
562                attr_name, attr_name
563            ),
564        ))
565    }
566}
567
// Unit tests for attribute parsing; they run the macro logic directly on `quote!`-built input,
// outside of a real derive invocation.
#[cfg(test)]
mod tests {
    use pretty_assertions_sorted::assert_eq_sorted;
    use syn::token::Pub;

    use super::*;

    /// Parsing a struct that uses container- and field-level `#[patch(...)]` options yields
    /// the expected `Patch`. The two are compared through their generated token streams
    /// (Debug-formatted), not field by field.
    #[test]
    fn parse_patch() {
        // Test case 1: Valid patch with attributes and fields
        let input = quote! {
            #[derive(Patch)]
            #[patch(name = "MyPatch", attribute(derive(Debug, PartialEq, Clone, Serialize, Deserialize)))]
            pub struct Item {
                #[patch(name = "SubItemPatch")]
                pub field1: SubItem,
                #[patch(skip)]
                pub field2: Option<String>,
            }
        };
        // `field2` is skipped, so only `field1` is expected. Its type is replaced by
        // `SubItemPatch` and it is therefore marked as retyped.
        let expected = Patch {
            visibility: syn::Visibility::Public(Pub::default()),
            struct_name: syn::Ident::new("Item", Span::call_site()),
            patch_struct_name: syn::Ident::new("MyPatch", Span::call_site()),
            generics: syn::Generics::default(),
            attributes: vec![quote! { derive(Debug, PartialEq, Clone, Serialize, Deserialize) }],
            fields: vec![Field {
                ident: Some(syn::Ident::new("field1", Span::call_site())),
                ty: LitStr::new("SubItemPatch", Span::call_site())
                    .parse()
                    .unwrap(),
                attributes: vec![],
                retyped: true,
                #[cfg(feature = "op")]
                addable: Addable::Disable,
            }],
        };
        let result = Patch::from_ast(syn::parse2(input).unwrap()).unwrap();
        assert_eq_sorted!(
            format!("{:?}", result.to_token_stream()),
            format!("{:?}", expected.to_token_stream())
        );
    }
}
611}