1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
use proc_macro2::{Group, Ident, Literal, Span, TokenStream, TokenTree};
use proc_macro_error::{abort, abort_call_site, proc_macro_error};
use quote::quote;
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
token, Attribute, Data, DataStruct, DeriveInput, Fields, LitStr, Token, Type, TypePath,
};
/// Parsed tail of a struct-level `#[patch = "path::to::Type"]` attribute,
/// i.e. the `= "..."` tokens that follow the attribute path.
struct PatchEqAttr {
    /// The `=` separator; stored only so that parsing consumes it.
    _eq_token: Token![=],
    /// String literal holding the target type path (parsed later).
    path: LitStr,
}

impl Parse for PatchEqAttr {
    /// Parses `= "<literal>"`, consuming the `=` first and then the string literal.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let eq: Token![=] = input.parse()?;
        let literal: LitStr = input.parse()?;
        Ok(PatchEqAttr {
            _eq_token: eq,
            path: literal,
        })
    }
}
/// Parsed tail of a field-level `#[patch(option)]` attribute, i.e. the
/// parenthesised `(ident)` tokens such as `(direct)` or `(as_option)`.
struct PatchParenAttr {
    /// The surrounding parentheses; stored only so that parsing consumes them.
    _paren_token: token::Paren,
    /// The single identifier found inside the parentheses.
    content: Ident,
}

impl Parse for PatchParenAttr {
    /// Parses `( <ident> )`; any extra tokens inside the parentheses make the
    /// enclosing `syn::parse2` call fail.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let inner;
        let parens = parenthesized!(inner in input);
        let option_name: Ident = inner.parse()?;
        Ok(PatchParenAttr {
            _paren_token: parens,
            content: option_name,
        })
    }
}
#[proc_macro_derive(Patch, attributes(patch))]
#[proc_macro_error]
pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let ident = input.ident;
let Data::Struct(DataStruct { fields, ..}) = input.data else { abort_call_site!("Patch can only be derived on structs") };
let fields = match fields {
Fields::Named(f) => f
.named
.into_pairs()
.map(|p| p.into_value())
.map(|f| (TokenTree::from(f.ident.unwrap()), f.ty, f.attrs))
.collect::<Vec<_>>(),
Fields::Unnamed(f) => f
.unnamed
.into_pairs()
.map(|p| p.into_value())
.enumerate()
.map(|(i, f)| {
(
TokenTree::from(Literal::u32_unsuffixed(i as u32)),
f.ty,
f.attrs,
)
})
.collect::<Vec<_>>(),
Fields::Unit => Vec::new(),
};
let mut targets = Vec::new();
for patch_target in get_patch_attrs(input.attrs) {
let span = patch_target.span();
let Ok(PatchEqAttr { path, ..}) = syn::parse2(patch_target) else { abort!(span, r#"Patch target must be specified in the form `#[patch = "path::to::Type"]`"#) };
let Ok(path) = parse_lit_str::<TypePath>(&path) else { abort!(&path, "`{}` is not a valid path", path.value())};
targets.push(path);
}
let mut apply_sets = Vec::new();
for (name, ty, attrs) in fields {
let Type::Path(TypePath { path, .. }) = &ty else { abort!(&ty, "Failed parsing field type as type path") };
let Some(ident) = path.segments.first().map(|e| &e.ident) else { abort!(&ty, "Field does not contain a valid ident") };
let mut direct = false;
let mut as_option = false;
for attr in get_patch_attrs(attrs) {
let span = attr.span();
let Ok(PatchParenAttr { content, .. }) = syn::parse2(attr) else { abort!(span, "Failed parsing field attribute") };
match content.to_string().as_str() {
"direct" => direct = true,
"as_option" => as_option = true,
_a => {
abort!(span, "Unknown attribute {a}")
}
}
}
if direct && as_option {
abort!(&ty, "Only one of `#[patch(direct)]` or `#[patch(as_option)]` may be specified for given field");
}
if as_option {
apply_sets.push(quote! {
if self.#name.is_some() {
target.#name = self.#name;
}
})
} else if &ident.to_string() == "Option" && !direct {
apply_sets.push(quote! {
if let Some(val) = self.#name {
target.#name = val;
}
});
} else {
apply_sets.push(quote! {
target.#name = self.#name;
});
}
}
let apply_content = quote! {
#(
#apply_sets
)*
};
let output = quote! {
#(
impl ::rust_patch::Patch<#targets> for #ident {
fn apply(self, mut target: #targets) -> #targets {
{ #apply_content }
target
}
}
)*
};
proc_macro::TokenStream::from(output)
}
fn get_patch_attrs(attrs: Vec<Attribute>) -> Vec<TokenStream> {
let mut result = Vec::new();
for Attribute { path, tokens, .. } in attrs {
if path
.segments
.first()
.map(|e| e.ident.to_string())
.as_deref()
== Some("patch")
{
result.push(tokens);
}
}
result
}
/// Parses the contents of a string literal as a `T`. Every resulting token carries
/// the literal's span, so parse errors point at the literal in the user's code.
fn parse_lit_str<T>(s: &syn::LitStr) -> syn::parse::Result<T>
where
    T: Parse,
{
    spanned_tokens(s).and_then(syn::parse2::<T>)
}
/// Tokenizes the value of a string literal and re-spans every token to the
/// literal's own span. Fails if the value is not a valid token stream.
fn spanned_tokens(s: &syn::LitStr) -> syn::parse::Result<TokenStream> {
    let source = s.value();
    syn::parse_str::<TokenStream>(&source).map(|tokens| respan(tokens, s.span()))
}
fn respan(stream: TokenStream, span: Span) -> TokenStream {
stream
.into_iter()
.map(|token| respan_token(token, span))
.collect()
}
/// Gives a single token tree `span`. A group is rebuilt with the same delimiter and
/// a recursively re-spanned inner stream before its own span is set.
fn respan_token(token: TokenTree, span: Span) -> TokenTree {
    let mut result = match token {
        TokenTree::Group(group) => {
            let inner = respan(group.stream(), span);
            TokenTree::Group(Group::new(group.delimiter(), inner))
        }
        leaf => leaf,
    };
    result.set_span(span);
    result
}