1extern crate proc_macro;
2use proc_macro2::{Ident, Span, TokenStream};
3use quote::{quote, ToTokens};
4use std::str::FromStr;
5use syn::{
6 meta::ParseNestedMeta, parenthesized, spanned::Spanned, DeriveInput, Error, LitStr, Result,
7 Type,
8};
9
// Keywords recognised by this derive.

// Name of the helper attribute itself: `#[patch(...)]`.
const PATCH: &str = "patch";
// `name = "..."`: on the container it names the generated patch struct, on a
// field it replaces the type that gets wrapped in `Option`.
const NAME: &str = "name";
// `attribute(...)`: tokens that are re-emitted as `#[...]` on the generated
// struct or field.
const ATTRIBUTE: &str = "attribute";
// `skip`: leaves the field out of the generated patch struct.
const SKIP: &str = "skip";
// `addable`: combines conflicting values with `a + b` (needs the `op` feature).
const ADDABLE: &str = "addable";
// `add = function`: combines conflicting values with `function(a, b)` (needs
// the `op` feature).
const ADD: &str = "add";
16
17#[proc_macro_derive(Patch, attributes(patch))]
18pub fn derive_patch(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
19 Patch::from_ast(syn::parse_macro_input!(item as syn::DeriveInput))
20 .unwrap()
21 .to_token_stream()
22 .unwrap()
23 .into()
24}
25
// Intermediate description of the derive input, produced by `Patch::from_ast`
// and rendered by `Patch::to_token_stream`.
struct Patch {
    // Visibility of the original struct; reused for the generated patch struct.
    visibility: syn::Visibility,
    // Identifier of the struct the derive is applied to.
    struct_name: Ident,
    // Identifier of the generated patch struct: the value of
    // `#[patch(name = "...")]`, or `<struct_name>Patch` when not given.
    patch_struct_name: Ident,
    // Generics of the original struct.
    generics: syn::Generics,
    // Token streams collected from `#[patch(attribute(...))]` on the container;
    // each one is emitted as `#[...]` on the patch struct.
    attributes: Vec<TokenStream>,
    // Fields of the original struct that were not marked `#[patch(skip)]`.
    fields: Vec<Field>,
}
34
// How the generated `Add` impl resolves a field that is `Some` on both sides.
#[cfg(feature = "op")]
enum Addable {
    // Default: the generated code panics when both patches carry a value.
    Disable,
    // Set by `#[patch(addable)]`: the generated code yields `Some(a + b)`.
    // (The variant name is a misspelling of "AddTrait"; it is referenced
    // elsewhere in this file under this spelling.)
    AddTriat,
    // Set by `#[patch(add = function)]`: the generated code yields
    // `Some(function(a, b))`.
    #[cfg(feature = "op")]
    AddFn(Ident),
}
42
// Description of one field of the generated patch struct.
struct Field {
    // Field name; `None` for an unnamed (tuple) field.
    ident: Option<Ident>,
    // Type that gets wrapped in `Option<...>`: the type parsed from
    // `#[patch(name = "...")]` when present, otherwise the original field type.
    ty: Type,
    // Token streams collected from `#[patch(attribute(...))]` on the field; each
    // one is emitted as `#[...]` on the patch field.
    attributes: Vec<TokenStream>,
    // `true` when the type was replaced through `#[patch(name = "...")]`.
    retyped: bool,
    // Conflict strategy used by the generated `Add` impl.
    #[cfg(feature = "op")]
    addable: Addable,
}
51
52impl Patch {
53 pub fn to_token_stream(&self) -> Result<TokenStream> {
55 let Patch {
56 visibility,
57 struct_name,
58 patch_struct_name: name,
59 generics,
60 attributes,
61 fields,
62 } = self;
63
64 let patch_struct_fields = fields
65 .iter()
66 .map(|f| f.to_token_stream())
67 .collect::<Result<Vec<_>>>()?;
68 let field_names = fields.iter().map(|f| f.ident.as_ref()).collect::<Vec<_>>();
69
70 let renamed_field_names = fields
71 .iter()
72 .filter(|f| f.retyped)
73 .map(|f| f.ident.as_ref())
74 .collect::<Vec<_>>();
75
76 let original_field_names = fields
77 .iter()
78 .filter(|f| !f.retyped)
79 .map(|f| f.ident.as_ref())
80 .collect::<Vec<_>>();
81
82 let mapped_attributes = attributes
83 .iter()
84 .map(|a| {
85 quote! {
86 #[#a]
87 }
88 })
89 .collect::<Vec<_>>();
90
91 let patch_struct = quote! {
92 #(#mapped_attributes)*
93 #visibility struct #name #generics {
94 #(#patch_struct_fields)*
95 }
96 };
97 let where_clause = &generics.where_clause;
98
99 #[cfg(feature = "status")]
100 let patch_status_impl = quote!(
101 impl #generics struct_patch::traits::PatchStatus for #name #generics #where_clause {
102 fn is_empty(&self) -> bool {
103 #(
104 if self.#field_names.is_some() {
105 return false
106 }
107 )*
108 true
109 }
110 }
111 );
112 #[cfg(not(feature = "status"))]
113 let patch_status_impl = quote!();
114
115 #[cfg(feature = "merge")]
116 let patch_merge_impl = quote!(
117 impl #generics struct_patch::traits::Merge for #name #generics #where_clause {
118 fn merge(self, other: Self) -> Self {
119 Self {
120 #(
121 #renamed_field_names: match (self.#renamed_field_names, other.#renamed_field_names) {
122 (Some(a), Some(b)) => Some(a.merge(b)),
123 (Some(a), None) => Some(a),
124 (None, Some(b)) => Some(b),
125 (None, None) => None,
126 },
127 )*
128 #(
129 #original_field_names: other.#original_field_names.or(self.#original_field_names),
130 )*
131 }
132 }
133 }
134 );
135 #[cfg(not(feature = "merge"))]
136 let patch_merge_impl = quote!();
137
138 #[cfg(feature = "op")]
139 let addable_handles = fields
140 .iter()
141 .map(|f| {
142 match &f.addable {
143 Addable::AddTriat => quote!(
144 Some(a + b)
145 ),
146 Addable::AddFn(f) => {
147 quote!(
148 Some(#f(a, b))
149 )
150 } ,
151 Addable::Disable => quote!(
152 panic!("There are conflict patches, please use `#[patch(addable)]` if you want to add these values.")
153 )
154 }
155 })
156 .collect::<Vec<_>>();
157
158 #[cfg(all(feature = "op", not(feature = "merge")))]
159 let op_impl = quote! {
160 impl #generics core::ops::Shl<#name #generics> for #struct_name #generics #where_clause {
161 type Output = Self;
162
163 fn shl(mut self, rhs: #name #generics) -> Self {
164 struct_patch::traits::Patch::apply(&mut self, rhs);
165 self
166 }
167 }
168
169 impl #generics core::ops::Add<Self> for #name #generics #where_clause {
170 type Output = Self;
171
172 fn add(mut self, rhs: Self) -> Self {
173 Self {
174 #(
175 #renamed_field_names: match (self.#renamed_field_names, rhs.#renamed_field_names) {
176 (Some(a), Some(b)) => {
177 #addable_handles
178 },
179 (Some(a), None) => Some(a),
180 (None, Some(b)) => Some(b),
181 (None, None) => None,
182 },
183 )*
184 #(
185 #original_field_names: match (self.#original_field_names, rhs.#original_field_names) {
186 (Some(a), Some(b)) => {
187 #addable_handles
188 },
189 (Some(a), None) => Some(a),
190 (None, Some(b)) => Some(b),
191 (None, None) => None,
192 },
193 )*
194 }
195 }
196 }
197 };
198
199 #[cfg(feature = "merge")]
200 let op_impl = quote! {
201 impl #generics core::ops::Shl<#name #generics> for #struct_name #generics #where_clause {
202 type Output = Self;
203
204 fn shl(mut self, rhs: #name #generics) -> Self {
205 struct_patch::traits::Patch::apply(&mut self, rhs);
206 self
207 }
208 }
209
210 impl #generics core::ops::Shl<#name #generics> for #name #generics #where_clause {
211 type Output = Self;
212
213 fn shl(mut self, rhs: Self) -> Self {
214 struct_patch::traits::Merge::merge(self, rhs)
215 }
216 }
217
218 impl #generics core::ops::Add<Self> for #name #generics #where_clause {
219 type Output = Self;
220
221 fn add(mut self, rhs: Self) -> Self {
222 Self {
223 #(
224 #renamed_field_names: match (self.#renamed_field_names, rhs.#renamed_field_names) {
225 (Some(a), Some(b)) => {
226 #addable_handles
227 },
228 (Some(a), None) => Some(a),
229 (None, Some(b)) => Some(b),
230 (None, None) => None,
231 },
232 )*
233 #(
234 #original_field_names: match (self.#original_field_names, rhs.#original_field_names) {
235 (Some(a), Some(b)) => {
236 #addable_handles
237 },
238 (Some(a), None) => Some(a),
239 (None, Some(b)) => Some(b),
240 (None, None) => None,
241 },
242 )*
243 }
244 }
245 }
246 };
247
248 #[cfg(not(feature = "op"))]
249 let op_impl = quote!();
250
251 let patch_impl = quote! {
252 impl #generics struct_patch::traits::Patch< #name #generics > for #struct_name #generics #where_clause {
253 fn apply(&mut self, patch: #name #generics) {
254 #(
255 if let Some(v) = patch.#renamed_field_names {
256 self.#renamed_field_names.apply(v);
257 }
258 )*
259 #(
260 if let Some(v) = patch.#original_field_names {
261 self.#original_field_names = v;
262 }
263 )*
264 }
265
266 fn into_patch(self) -> #name #generics {
267 #name {
268 #(
269 #renamed_field_names: Some(self.#renamed_field_names.into_patch()),
270 )*
271 #(
272 #original_field_names: Some(self.#original_field_names),
273 )*
274 }
275 }
276
277 fn into_patch_by_diff(self, previous_struct: Self) -> #name #generics {
278 #name {
279 #(
280 #renamed_field_names: if self.#renamed_field_names != previous_struct.#renamed_field_names {
281 Some(self.#renamed_field_names.into_patch_by_diff(previous_struct.#renamed_field_names))
282 }
283 else {
284 None
285 },
286 )*
287 #(
288 #original_field_names: if self.#original_field_names != previous_struct.#original_field_names {
289 Some(self.#original_field_names)
290 }
291 else {
292 None
293 },
294 )*
295 }
296 }
297
298 fn new_empty_patch() -> #name #generics {
299 #name {
300 #(
301 #field_names: None,
302 )*
303 }
304 }
305 }
306 };
307
308 Ok(quote! {
309 #patch_struct
310
311 #patch_status_impl
312
313 #patch_merge_impl
314
315 #patch_impl
316
317 #op_impl
318 })
319 }
320
321 pub fn from_ast(
323 DeriveInput {
324 ident,
325 data,
326 generics,
327 attrs,
328 vis,
329 }: syn::DeriveInput,
330 ) -> Result<Patch> {
331 let original_fields = if let syn::Data::Struct(syn::DataStruct { fields, .. }) = data {
332 fields
333 } else {
334 return Err(syn::Error::new(
335 ident.span(),
336 "Patch derive only use for struct",
337 ));
338 };
339
340 let mut name = None;
341 let mut attributes = vec![];
342 let mut fields = vec![];
343
344 for attr in attrs {
345 if attr.path().to_string().as_str() != PATCH {
346 continue;
347 }
348
349 if let syn::Meta::List(meta) = &attr.meta {
350 if meta.tokens.is_empty() {
351 continue;
352 }
353 }
354
355 attr.parse_nested_meta(|meta| {
356 let path = meta.path.to_string();
357 match path.as_str() {
358 NAME => {
359 if let Some(lit) = get_lit_str(path, &meta)? {
361 if name.is_some() {
362 return Err(meta
363 .error("The name attribute can't be defined more than once"));
364 }
365 name = Some(lit.parse()?);
366 }
367 }
368 ATTRIBUTE => {
369 let content;
372 parenthesized!(content in meta.input);
373 let attribute: TokenStream = content.parse()?;
374 attributes.push(attribute);
375 }
376 _ => {
377 return Err(meta.error(format_args!(
378 "unknown patch container attribute `{}`",
379 path.replace(' ', "")
380 )));
381 }
382 }
383 Ok(())
384 })?;
385 }
386
387 for field in original_fields {
388 if let Some(f) = Field::from_ast(field)? {
389 fields.push(f);
390 }
391 }
392
393 Ok(Patch {
394 visibility: vis,
395 patch_struct_name: name.unwrap_or({
396 let ts = TokenStream::from_str(&format!("{}Patch", &ident,)).unwrap();
397 let lit = LitStr::new(&ts.to_string(), Span::call_site());
398 lit.parse()?
399 }),
400 struct_name: ident,
401 generics,
402 attributes,
403 fields,
404 })
405 }
406}
407
408impl Field {
409 pub fn to_token_stream(&self) -> Result<TokenStream> {
411 let Field {
412 ident,
413 ty,
414 attributes,
415 ..
416 } = self;
417
418 let attributes = attributes
419 .iter()
420 .map(|a| {
421 quote! {
422 #[#a]
423 }
424 })
425 .collect::<Vec<_>>();
426 match ident {
427 Some(ident) => Ok(quote! {
428 #(#attributes)*
429 pub #ident: Option<#ty>,
430 }),
431 None => Ok(quote! {
432 #(#attributes)*
433 pub Option<#ty>,
434 }),
435 }
436 }
437
438 pub fn from_ast(
440 syn::Field {
441 ident, ty, attrs, ..
442 }: syn::Field,
443 ) -> Result<Option<Field>> {
444 let mut attributes = vec![];
445 let mut field_type = None;
446 let mut skip = false;
447
448 #[cfg(feature = "op")]
449 let mut addable = Addable::Disable;
450
451 for attr in attrs {
452 if attr.path().to_string().as_str() != PATCH {
453 continue;
454 }
455
456 if let syn::Meta::List(meta) = &attr.meta {
457 if meta.tokens.is_empty() {
458 continue;
459 }
460 }
461
462 attr.parse_nested_meta(|meta| {
463 let path = meta.path.to_string();
464 match path.as_str() {
465 SKIP => {
466 skip = true;
468 }
469 ATTRIBUTE => {
470 let content;
472 parenthesized!(content in meta.input);
473 let attribute: TokenStream = content.parse()?;
474 attributes.push(attribute);
475 }
476 NAME => {
477 let expr: LitStr = meta.value()?.parse()?;
479 field_type = Some(expr.parse()?)
480 }
481 #[cfg(feature = "op")]
482 ADDABLE => {
483 addable = Addable::AddTriat;
485 }
486 #[cfg(not(feature = "op"))]
487 ADDABLE => {
488 return Err(syn::Error::new(
489 ident.span(),
490 "`addable` needs `op` feature",
491 ));
492 }
493 #[cfg(feature = "op")]
494 ADD => {
495 let f: Ident = meta.value()?.parse()?;
497 addable = Addable::AddFn(f);
498 }
499 #[cfg(not(feature = "op"))]
500 ADD => {
501 return Err(syn::Error::new(ident.span(), "`add` needs `op` feature"));
502 }
503 _ => {
504 return Err(meta.error(format_args!(
505 "unknown patch field attribute `{}`",
506 path.replace(' ', "")
507 )));
508 }
509 }
510 Ok(())
511 })?;
512 if skip {
513 return Ok(None);
514 }
515 }
516
517 Ok(Some(Field {
518 ident,
519 retyped: field_type.is_some(),
520 ty: field_type.unwrap_or(ty),
521 attributes,
522 #[cfg(feature = "op")]
523 addable,
524 }))
525 }
526}
527
// Local helper trait that gives `syn::Path` a `to_string` method (implemented
// below by rendering the path's tokens), so attribute paths can be compared
// against the keyword constants.
trait ToStr {
    // Returns the textual form of `self`.
    fn to_string(&self) -> String;
}
531
532impl ToStr for syn::Path {
533 fn to_string(&self) -> String {
534 self.to_token_stream().to_string()
535 }
536}
537
538fn get_lit_str(attr_name: String, meta: &ParseNestedMeta) -> syn::Result<Option<syn::LitStr>> {
539 let expr: syn::Expr = meta.value()?.parse()?;
540 let mut value = &expr;
541 while let syn::Expr::Group(e) = value {
542 value = &e.expr;
543 }
544 if let syn::Expr::Lit(syn::ExprLit {
545 lit: syn::Lit::Str(lit),
546 ..
547 }) = value
548 {
549 let suffix = lit.suffix();
550 if !suffix.is_empty() {
551 return Err(Error::new(
552 lit.span(),
553 format!("unexpected suffix `{}` on string literal", suffix),
554 ));
555 }
556 Ok(Some(lit.clone()))
557 } else {
558 Err(Error::new(
559 expr.span(),
560 format!(
561 "expected serde {} attribute to be a string: `{} = \"...\"`",
562 attr_name, attr_name
563 ),
564 ))
565 }
566}
567
// Unit tests for the attribute parsing done by `Patch::from_ast`.
#[cfg(test)]
mod tests {
    use pretty_assertions_sorted::assert_eq_sorted;
    use syn::token::Pub;

    use super::*;

    // Parses a struct that uses container-level `name` / `attribute(...)` and
    // field-level `name` / `skip`, then checks that the rendered output equals
    // the output rendered from a hand-built `Patch`.
    #[test]
    fn parse_patch() {
        let input = quote! {
            #[derive(Patch)]
            #[patch(name = "MyPatch", attribute(derive(Debug, PartialEq, Clone, Serialize, Deserialize)))]
            pub struct Item {
                #[patch(name = "SubItemPatch")]
                pub field1: SubItem,
                #[patch(skip)]
                pub field2: Option<String>,
            }
        };
        // `field2` is marked `skip`, so only `field1` is expected, with its
        // type replaced by `SubItemPatch`.
        let expected = Patch {
            visibility: syn::Visibility::Public(Pub::default()),
            struct_name: syn::Ident::new("Item", Span::call_site()),
            patch_struct_name: syn::Ident::new("MyPatch", Span::call_site()),
            generics: syn::Generics::default(),
            attributes: vec![quote! { derive(Debug, PartialEq, Clone, Serialize, Deserialize) }],
            fields: vec![Field {
                ident: Some(syn::Ident::new("field1", Span::call_site())),
                ty: LitStr::new("SubItemPatch", Span::call_site())
                    .parse()
                    .unwrap(),
                attributes: vec![],
                retyped: true,
                #[cfg(feature = "op")]
                addable: Addable::Disable,
            }],
        };
        let result = Patch::from_ast(syn::parse2(input).unwrap()).unwrap();
        // Both sides are compared through the debug rendering of their
        // generated token streams.
        assert_eq_sorted!(
            format!("{:?}", result.to_token_stream()),
            format!("{:?}", expected.to_token_stream())
        );
    }
}