1use proc_macro2::{Group, Ident, Literal, Span, TokenStream, TokenTree};
2use proc_macro_error::{abort, abort_call_site, proc_macro_error};
3use quote::quote;
4use syn::{
5 parenthesized,
6 parse::{Parse, ParseStream},
7 parse_macro_input,
8 spanned::Spanned,
9 token, Attribute, Data, DataStruct, DeriveInput, Fields, Token, Type, TypePath,
10};
11
12struct PatchEqAttr {
13 _eq_token: Token![=],
14 path: TypePath,
15}
16
17impl Parse for PatchEqAttr {
18 fn parse(input: ParseStream) -> syn::Result<Self> {
19 Ok(Self {
20 _eq_token: input.parse()?,
21 path: parse_lit_str(&input.parse()?)?,
22 })
23 }
24}
25
26struct PatchParenAttr {
27 _paren_token: token::Paren,
28 content: Ident,
29}
30
31impl Parse for PatchParenAttr {
32 fn parse(input: ParseStream) -> syn::Result<Self> {
33 let content;
34 Ok(Self {
35 _paren_token: parenthesized!(content in input),
36 content: content.parse()?,
37 })
38 }
39}
40
41#[proc_macro_derive(Patch, attributes(patch))]
42#[proc_macro_error]
43pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
44 let input = parse_macro_input!(item as DeriveInput);
45
46 let ident = input.ident;
47 let Data::Struct(DataStruct { fields, ..}) = input.data else { abort_call_site!("Patch can only be derived on structs") };
48 let fields = match fields {
49 Fields::Named(f) => f
50 .named
51 .into_pairs()
52 .map(|p| p.into_value())
53 .map(|f| (TokenTree::from(f.ident.unwrap()), f.ty, f.attrs))
54 .collect::<Vec<_>>(),
55 Fields::Unnamed(f) => f
56 .unnamed
57 .into_pairs()
58 .map(|p| p.into_value())
59 .enumerate()
60 .map(|(i, f)| {
61 (
62 TokenTree::from(Literal::u32_unsuffixed(i as u32)),
63 f.ty,
64 f.attrs,
65 )
66 })
67 .collect::<Vec<_>>(),
68 Fields::Unit => Vec::new(),
69 };
70
71 let mut targets = Vec::new();
72 for patch_target in get_patch_attrs(input.attrs) {
73 let span = patch_target.span();
74 let Ok(PatchEqAttr { path, ..}) = syn::parse2(patch_target) else { abort!(span, r#"Patch target must be specified in the form `#[patch = "path::to::Type"]`"#) };
75 targets.push(path);
76 }
77
78 let mut apply_sets = Vec::new();
79 for (name, ty, attrs) in fields {
80 let Type::Path(TypePath { path, .. }) = &ty else { abort!(&ty, "Failed parsing field type as type path") };
81 let Some(ident) = path.segments.first().map(|e| &e.ident) else { abort!(&ty, "Field does not contain a valid ident") };
82 let mut direct = false;
83 let mut as_option = false;
84 for attr in get_patch_attrs(attrs) {
85 let span = attr.span();
86 let content = match syn::parse2(attr) {
87 Ok(PatchParenAttr { content, .. }) => content,
88 Err(e) => abort!(span, "Failed parsing attribute: {}", e),
89 };
90 match content.to_string().as_str() {
91 "direct" => direct = true,
92 "as_option" => as_option = true,
93 a => {
94 abort!(span, "Unknown attribute `{}`", a)
95 }
96 }
97 }
98 if direct && as_option {
99 abort!(&ty, "Only one of `#[patch(direct)]` or `#[patch(as_option)]` may be specified for given field");
100 }
101 if as_option {
102 apply_sets.push(quote! {
103 if self.#name.is_some() {
104 target.#name = self.#name;
105 }
106 })
107 } else if &ident.to_string() == "Option" && !direct {
108 apply_sets.push(quote! {
109 if let Some(val) = self.#name {
110 target.#name = val;
111 }
112 });
113 } else {
114 apply_sets.push(quote! {
115 target.#name = self.#name;
116 });
117 }
118 }
119
120 let apply_content = quote! {
121 #(
122 #apply_sets
123 )*
124 };
125
126 let output = quote! {
127 #(
128 impl ::rust_patch::Patch<#targets> for #ident {
129 fn apply(self, mut target: #targets) -> #targets {
130 { #apply_content }
131 target
132 }
133 }
134 )*
135 };
136
137 proc_macro::TokenStream::from(output)
138}
139
140fn get_patch_attrs(attrs: Vec<Attribute>) -> Vec<TokenStream> {
141 let mut result = Vec::new();
142 for Attribute { path, tokens, .. } in attrs {
143 if path
144 .segments
145 .first()
146 .map(|e| e.ident.to_string())
147 .as_deref()
148 == Some("patch")
149 {
150 result.push(tokens);
151 }
152 }
153 result
154}
155
156fn parse_lit_str<T>(s: &syn::LitStr) -> syn::parse::Result<T>
158where
159 T: Parse,
160{
161 let tokens = spanned_tokens(s)?;
162 syn::parse2(tokens)
163}
164
165fn spanned_tokens(s: &syn::LitStr) -> syn::parse::Result<TokenStream> {
166 let stream = syn::parse_str(&s.value())?;
167 Ok(respan(stream, s.span()))
168}
169
170fn respan(stream: TokenStream, span: Span) -> TokenStream {
171 stream
172 .into_iter()
173 .map(|token| respan_token(token, span))
174 .collect()
175}
176
177fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
178 if let TokenTree::Group(g) = &mut token {
179 *g = Group::new(g.delimiter(), respan(g.stream(), span));
180 }
181 token.set_span(span);
182 token
183}