rust_patch_derive/
lib.rs

1use proc_macro2::{Group, Ident, Literal, Span, TokenStream, TokenTree};
2use proc_macro_error::{abort, abort_call_site, proc_macro_error};
3use quote::quote;
4use syn::{
5    parenthesized,
6    parse::{Parse, ParseStream},
7    parse_macro_input,
8    spanned::Spanned,
9    token, Attribute, Data, DataStruct, DeriveInput, Fields, Token, Type, TypePath,
10};
11
12struct PatchEqAttr {
13    _eq_token: Token![=],
14    path: TypePath,
15}
16
17impl Parse for PatchEqAttr {
18    fn parse(input: ParseStream) -> syn::Result<Self> {
19        Ok(Self {
20            _eq_token: input.parse()?,
21            path: parse_lit_str(&input.parse()?)?,
22        })
23    }
24}
25
26struct PatchParenAttr {
27    _paren_token: token::Paren,
28    content: Ident,
29}
30
31impl Parse for PatchParenAttr {
32    fn parse(input: ParseStream) -> syn::Result<Self> {
33        let content;
34        Ok(Self {
35            _paren_token: parenthesized!(content in input),
36            content: content.parse()?,
37        })
38    }
39}
40
41#[proc_macro_derive(Patch, attributes(patch))]
42#[proc_macro_error]
43pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
44    let input = parse_macro_input!(item as DeriveInput);
45
46    let ident = input.ident;
47    let Data::Struct(DataStruct { fields, ..}) = input.data else { abort_call_site!("Patch can only be derived on structs") };
48    let fields = match fields {
49        Fields::Named(f) => f
50            .named
51            .into_pairs()
52            .map(|p| p.into_value())
53            .map(|f| (TokenTree::from(f.ident.unwrap()), f.ty, f.attrs))
54            .collect::<Vec<_>>(),
55        Fields::Unnamed(f) => f
56            .unnamed
57            .into_pairs()
58            .map(|p| p.into_value())
59            .enumerate()
60            .map(|(i, f)| {
61                (
62                    TokenTree::from(Literal::u32_unsuffixed(i as u32)),
63                    f.ty,
64                    f.attrs,
65                )
66            })
67            .collect::<Vec<_>>(),
68        Fields::Unit => Vec::new(),
69    };
70
71    let mut targets = Vec::new();
72    for patch_target in get_patch_attrs(input.attrs) {
73        let span = patch_target.span();
74        let Ok(PatchEqAttr { path, ..}) = syn::parse2(patch_target) else { abort!(span, r#"Patch target must be specified in the form `#[patch = "path::to::Type"]`"#) };
75        targets.push(path);
76    }
77
78    let mut apply_sets = Vec::new();
79    for (name, ty, attrs) in fields {
80        let Type::Path(TypePath { path, .. }) = &ty else { abort!(&ty, "Failed parsing field type as type path") };
81        let Some(ident) = path.segments.first().map(|e| &e.ident) else { abort!(&ty, "Field does not contain a valid ident") };
82        let mut direct = false;
83        let mut as_option = false;
84        for attr in get_patch_attrs(attrs) {
85            let span = attr.span();
86            let content = match syn::parse2(attr) {
87                Ok(PatchParenAttr { content, .. }) => content,
88                Err(e) => abort!(span, "Failed parsing attribute: {}", e),
89            };
90            match content.to_string().as_str() {
91                "direct" => direct = true,
92                "as_option" => as_option = true,
93                a => {
94                    abort!(span, "Unknown attribute `{}`", a)
95                }
96            }
97        }
98        if direct && as_option {
99            abort!(&ty, "Only one of `#[patch(direct)]` or `#[patch(as_option)]` may be specified for given field");
100        }
101        if as_option {
102            apply_sets.push(quote! {
103                if self.#name.is_some() {
104                    target.#name = self.#name;
105                }
106            })
107        } else if &ident.to_string() == "Option" && !direct {
108            apply_sets.push(quote! {
109                if let Some(val) = self.#name {
110                    target.#name = val;
111                }
112            });
113        } else {
114            apply_sets.push(quote! {
115                target.#name = self.#name;
116            });
117        }
118    }
119
120    let apply_content = quote! {
121        #(
122            #apply_sets
123        )*
124    };
125
126    let output = quote! {
127        #(
128            impl ::rust_patch::Patch<#targets> for #ident {
129                fn apply(self, mut target: #targets) -> #targets {
130                    { #apply_content }
131                    target
132                }
133            }
134        )*
135    };
136
137    proc_macro::TokenStream::from(output)
138}
139
140fn get_patch_attrs(attrs: Vec<Attribute>) -> Vec<TokenStream> {
141    let mut result = Vec::new();
142    for Attribute { path, tokens, .. } in attrs {
143        if path
144            .segments
145            .first()
146            .map(|e| e.ident.to_string())
147            .as_deref()
148            == Some("patch")
149        {
150            result.push(tokens);
151        }
152    }
153    result
154}
155
156// Taken from https://github.com/serde-rs/serde/blob/master/serde_derive/src/internals
157fn parse_lit_str<T>(s: &syn::LitStr) -> syn::parse::Result<T>
158where
159    T: Parse,
160{
161    let tokens = spanned_tokens(s)?;
162    syn::parse2(tokens)
163}
164
165fn spanned_tokens(s: &syn::LitStr) -> syn::parse::Result<TokenStream> {
166    let stream = syn::parse_str(&s.value())?;
167    Ok(respan(stream, s.span()))
168}
169
170fn respan(stream: TokenStream, span: Span) -> TokenStream {
171    stream
172        .into_iter()
173        .map(|token| respan_token(token, span))
174        .collect()
175}
176
177fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
178    if let TokenTree::Group(g) = &mut token {
179        *g = Group::new(g.delimiter(), respan(g.stream(), span));
180    }
181    token.set_span(span);
182    token
183}