1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use proc_macro2::{Group, Literal, Span, TokenStream, TokenTree};
use proc_macro_error::proc_macro_error;
use proc_macro_error::{abort, abort_call_site};
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, Attribute, Data, DataStruct, LitStr, Token, Type, TypePath};
use syn::{DeriveInput, Fields};
struct PatchAttr {
_eq_token: Token![=],
path: LitStr,
}
impl Parse for PatchAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(Self {
_eq_token: input.parse()?,
path: input.parse()?,
})
}
}
#[proc_macro_derive(Patch, attributes(patch))]
#[proc_macro_error]
pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let ident = input.ident;
let Data::Struct(DataStruct { fields, ..}) = input.data else { abort_call_site!("Patch can only be derived on structs") };
let fields = match fields {
Fields::Named(f) => f
.named
.into_pairs()
.map(|p| p.into_value())
.map(|f| (TokenTree::from(f.ident.unwrap()), f.ty))
.collect::<Vec<_>>(),
Fields::Unnamed(f) => f
.unnamed
.into_pairs()
.map(|p| p.into_value())
.enumerate()
.map(|(i, f)| (TokenTree::from(Literal::u32_unsuffixed(i as u32)), f.ty))
.collect::<Vec<_>>(),
Fields::Unit => Vec::new(),
};
let mut targets = Vec::new();
for Attribute { path, tokens, .. } in input.attrs {
if path
.segments
.first()
.map(|e| e.ident.to_string())
.as_deref()
== Some("patch")
{
let Ok(PatchAttr { path, ..}) = syn::parse2(tokens) else { abort!(&path, r#"Patch target must be specified in the form `#[patch = "path::to::Type"]`"#) };
let Ok(path) = parse_lit_str::<TypePath>(&path) else { abort!(&path, "`{}` is not a valid path", path.value())};
targets.push(path);
}
}
let mut apply_sets = Vec::new();
for (name, ty) in fields {
if let Type::Path(TypePath { path, .. }) = &ty {
let Some(ident) = path.segments.first().map(|e| &e.ident) else { abort!(&ty, "Failed parsing field") };
if &ident.to_string() == "Option" {
apply_sets.push(quote! {
if let Some(val) = self.#name {
target.#name = val;
}
});
} else {
apply_sets.push(quote! {
target.#name = self.#name;
});
}
}
}
let apply_content = quote! {
#(
#apply_sets
)*
};
let output = quote! {
#(
impl ::rust_patch::Patch<#targets> for #ident {
fn apply(self, mut target: #targets) -> #targets {
{ #apply_content }
target
}
}
)*
};
proc_macro::TokenStream::from(output)
}
fn parse_lit_str<T>(s: &syn::LitStr) -> syn::parse::Result<T>
where
T: Parse,
{
let tokens = spanned_tokens(s)?;
syn::parse2(tokens)
}
fn spanned_tokens(s: &syn::LitStr) -> syn::parse::Result<TokenStream> {
let stream = syn::parse_str(&s.value())?;
Ok(respan(stream, s.span()))
}
fn respan(stream: TokenStream, span: Span) -> TokenStream {
stream
.into_iter()
.map(|token| respan_token(token, span))
.collect()
}
fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
if let TokenTree::Group(g) = &mut token {
*g = Group::new(g.delimiter(), respan(g.stream(), span));
}
token.set_span(span);
token
}