uniplate_derive/
lib.rs

1mod ast;
2mod prelude;
3mod state;
4
5use std::collections::VecDeque;
6
7use prelude::*;
8use quote::format_ident;
9use syn::parse_macro_input;
10
11#[proc_macro_derive(Uniplate, attributes(uniplate, biplate))]
12pub fn uniplate_derive(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as ast::DeriveInput);
14    let mut state: ParserState = ParserState::new(input.clone());
15
16    let mut out_tokens: Vec<TokenStream2> = Vec::new();
17    out_tokens.push(quote! {
18        use std::borrow::Borrow as _;
19    });
20
21    while state.next_instance().is_some() {
22        out_tokens.push(match &state.current_instance {
23            Some(ast::InstanceMeta::Uniplate(_)) => derive_a_uniplate(&mut state),
24            Some(ast::InstanceMeta::Biplate(_)) => derive_a_biplate(&mut state),
25            _ => unreachable!(),
26        });
27    }
28
29    out_tokens.into_iter().collect::<TokenStream2>().into()
30}
31
32fn derive_a_uniplate(state: &mut ParserState) -> TokenStream2 {
33    let from = state.from.to_token_stream();
34    let tokens: TokenStream2 = match state.data.clone() {
35        ast::Data::DataEnum(x) => _derive_a_enum_uniplate(state, x),
36        ast::Data::DataStruct(x) => _derive_a_struct_uniplate(state, x),
37    };
38
39    let mut generics = state.data.generics().clone();
40    for (_, bounds) in generics.type_parameters.iter_mut() {
41        // Add 'static bounds to all generic type parameters.
42        bounds.push(syn::TypeParamBound::Verbatim(quote!('static)));
43    }
44
45    let impl_bounds = generics.impl_parameters();
46    let where_clause = generics.impl_type_where_block();
47    quote! {
48        impl<#impl_bounds> ::uniplate::Uniplate for #from #where_clause {
49            fn uniplate(&self) -> (::uniplate::Tree<#from>, Box<dyn Fn(::uniplate::Tree<#from>) -> #from>) {
50                #tokens
51            }
52        }
53    }
54}
55
56fn _derive_a_enum_uniplate(state: &mut ParserState, data: ast::DataEnum) -> TokenStream2 {
57    let mut variant_tokens = VecDeque::<TokenStream2>::new();
58    for variant in data.variants {
59        let fields = &variant.fields;
60        let field_idents: Vec<_> = fields.idents().collect();
61
62        let field_defs: Vec<_> = fields
63            .defs()
64            .map(|(mem, typ)| _derive_for_field_enum(state, typ, &mem))
65            .collect();
66
67        let children_def = _derive_children(state, fields);
68        let ctx_def = _derive_ctx(state, fields, Some(&variant.ident));
69        let ident = variant.ident;
70        let enum_ident = state.data.ident();
71
72        match variant.fields {
73            ast::Fields::Struct(_) => {
74                variant_tokens.push_back(quote! {
75                    #enum_ident::#ident{#(#field_idents),*} => {
76                        #(#field_defs)*
77
78                        #children_def
79
80                        #ctx_def
81
82                        (children,ctx)
83                    },
84                });
85            }
86
87            ast::Fields::Tuple(_) => {
88                variant_tokens.push_back(quote! {
89                    #enum_ident::#ident(#(#field_idents),*) => {
90                        #(#field_defs)*
91
92                        #children_def
93
94                        #ctx_def
95
96                        (children,ctx)
97                    },
98                });
99            }
100            ast::Fields::Unit => {
101                variant_tokens.push_back(quote! {
102                    #enum_ident::#ident => {
103                        #children_def
104
105                        #ctx_def
106
107                        (children,ctx)
108                    },
109                });
110            }
111        }
112    }
113
114    let variant_tokens = variant_tokens.iter();
115    quote! {
116        match self {
117            #(#variant_tokens)*
118        }
119    }
120}
121
122fn _derive_a_struct_uniplate(state: &mut ParserState, data: ast::DataStruct) -> TokenStream2 {
123    let struct_ident = state.data.ident();
124    if data.fields.is_empty() {
125        // Unit-like or empty struct
126        return quote! {
127            (::uniplate::Tree::Zero,Box::new(|_| #struct_ident))
128        };
129    }
130
131    let field_defs: Vec<_> = data
132        .fields
133        .defs()
134        .map(|(mem, typ)| _derive_for_field_struct(state, typ, mem))
135        .collect();
136    let children_def = _derive_children(state, &data.fields);
137    let ctx_def = _derive_ctx(state, &data.fields, None);
138
139    quote! {
140        #(#field_defs)*
141
142        #children_def
143
144        #ctx_def
145
146        (children,ctx)
147    }
148}
149
150fn _derive_for_field_enum(
151    state: &mut ParserState,
152    field_type: &ast::Type,
153    member: &syn::Member,
154) -> TokenStream2 {
155    // the identifier used in the match clause.
156    // either _1, or the field name.
157    let match_ident = match member {
158        syn::Member::Named(ident) => ident.clone(),
159        syn::Member::Unnamed(index) => format_ident!("_{}", index),
160    };
161
162    let children_ident = format_ident!("_{}_children", member);
163    let ctx_ident = format_ident!("_{}_ctx", member);
164
165    let to_t = state.to.clone().expect("").to_token_stream();
166
167    match field_type {
168        ast::Type::BoxedBasic(_) => {
169            quote! {
170                let (#children_ident,#ctx_ident) = ::uniplate::spez::try_biplate_to!((**#match_ident).clone(), #to_t);
171            }
172        }
173        ast::Type::Basic(_) => {
174            quote! {
175                let (#children_ident,#ctx_ident) = ::uniplate::spez::try_biplate_to!(#match_ident.clone(), #to_t);
176            }
177        }
178        ast::Type::Tuple(tuple_type) => {
179            // destructure the tuple
180            let tuple_field_idents =
181                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}", member));
182            let destructure_tuple = quote! {
183                    let (#(#tuple_field_idents),*) = #match_ident;
184            };
185
186            // call biplate on each tuple field
187            let call_biplate_for_each_field = tuple_type.fields.iter().enumerate().map(|(i,_)| {
188                let field_ident = format_ident!("_{}_tuple_field_{i}",member);
189                let field_children_ident = format_ident!("_{}_tuple_field_{i}_children",member);
190                let field_ctx_ident = format_ident!("_{}_tuple_field_{i}_ctx", member);
191
192                // let index = syn::Index::from(i);
193                quote!{
194                    let (#field_children_ident,#field_ctx_ident) = ::uniplate::spez::try_biplate_to!(#field_ident.clone(), #to_t);
195                }
196            });
197
198            let tuple_field_children_idents =
199                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_children", member));
200            let tuple_field_ctx_idents =
201                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_ctx", member));
202
203            // build the children tree by combining each fields' tree
204            let build_child_tree = quote! {
205                let #children_ident = ::uniplate::Tree::Many(::std::collections::VecDeque::from([#(#tuple_field_children_idents),*]));
206            };
207
208            let index = (0..tuple_type.n).map(syn::Index::from);
209
210            // build the context function
211            let build_child_ctx = quote! {
212                let #ctx_ident = Box::new(move |x| {
213                    let ::uniplate::Tree::Many(xs) = x else {
214                        panic!()
215                    };
216
217                    (#(#tuple_field_ctx_idents(xs[#index].clone())),*)
218                });
219            };
220
221            quote! {
222                #destructure_tuple
223                #(#call_biplate_for_each_field);*
224                #build_child_tree
225                #build_child_ctx
226            }
227        }
228        ast::Type::BoxedTuple(tuple_type) => {
229            // destructure the tuple
230            let tuple_field_idents =
231                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}", member));
232            let destructure_tuple = quote! {
233                    let (#(#tuple_field_idents),*) = (**#match_ident).clone();
234            };
235
236            // call biplate on each tuple field
237            let call_biplate_for_each_field = tuple_type.fields.iter().enumerate().map(|(i,_)| {
238                let field_ident = format_ident!("_{}_tuple_field_{i}",member);
239                let field_children_ident = format_ident!("_{}_tuple_field_{i}_children",member);
240                let field_ctx_ident = format_ident!("_{}_tuple_field_{i}_ctx", member);            
241
242                // let index = syn::Index::from(i);
243                quote!{
244                    let (#field_children_ident,#field_ctx_ident) = ::uniplate::spez::try_biplate_to!(#field_ident.clone(), #to_t);
245                }
246            });
247
248            let tuple_field_children_idents =
249                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_children", member));
250            let tuple_field_ctx_idents =
251                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_ctx", member));
252
253            // build the children tree by combining each fields' tree
254            let build_child_tree = quote! {
255                let #children_ident = ::uniplate::Tree::Many(::std::collections::VecDeque::from([#(#tuple_field_children_idents),*]));
256            };
257
258            let index = (0..tuple_type.n).map(syn::Index::from);
259
260            // build the context function
261            let build_child_ctx = quote! {
262                let #ctx_ident = Box::new(move |x| {
263                    let ::uniplate::Tree::Many(xs) = x else {
264                        panic!()
265                    };
266
267                    (#(#tuple_field_ctx_idents(xs[#index].clone())),*)
268                });
269            };
270
271            quote! {
272                #destructure_tuple
273                #(#call_biplate_for_each_field);*
274                #build_child_tree
275                #build_child_ctx
276            }
277        }
278    }
279}
280
281fn _derive_for_field_struct(
282    state: &mut ParserState,
283    field_type: &ast::Type,
284    member: syn::Member,
285) -> TokenStream2 {
286    let children_ident = format_ident!("_{}_children", member);
287    let ctx_ident = format_ident!("_{}_ctx", member);
288
289    let to_t = state.to.clone().expect("").to_token_stream();
290
291    match field_type {
292        ast::Type::BoxedBasic(_) => {
293            quote! {
294                let (#children_ident,#ctx_ident) = ::uniplate::spez::try_biplate_to!((*self.#member).clone(), #to_t);
295            }
296        }
297        ast::Type::Basic(_) => {
298            quote! {
299                let (#children_ident,#ctx_ident) = ::uniplate::try_biplate_to!(self.#member.clone(), #to_t);
300            }
301        }
302        ast::Type::Tuple(tuple_type) => {
303            // destructure the tuple
304            let tuple_field_idents =
305                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}", member));
306            let destructure_tuple = quote! {
307                    let (#(#tuple_field_idents),*) = self.#member.clone();
308            };
309
310            // call biplate on each tuple field
311            let call_biplate_for_each_field = tuple_type.fields.iter().enumerate().map(|(i,_)| {
312                let field_ident = format_ident!("_{}_tuple_field_{i}",member);
313                let field_children_ident = format_ident!("_{}_tuple_field_{i}_children",member);
314                let field_ctx_ident = format_ident!("_{}_tuple_field_{i}_ctx", member);
315
316                // let index = syn::Index::from(i);
317                quote!{
318                    let (#field_children_ident,#field_ctx_ident) = ::uniplate::spez::try_biplate_to!(#field_ident.clone(), #to_t);
319                }
320            });
321
322            let tuple_field_children_idents =
323                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_children", member));
324            let tuple_field_ctx_idents =
325                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_ctx", member));
326
327            // build the children tree by combining each fields' tree
328            let build_child_tree = quote! {
329                let #children_ident = ::uniplate::Tree::Many(::std::collections::VecDeque::from([#(#tuple_field_children_idents),*]));
330            };
331
332            let index = (0..tuple_type.n).map(syn::Index::from);
333
334            // build the context function
335            let build_child_ctx = quote! {
336                let #ctx_ident = Box::new(move |x| {
337                    let ::uniplate::Tree::Many(xs) = x else {
338                        panic!()
339                    };
340
341                    (#(#tuple_field_ctx_idents(xs[#index].clone())),*)
342                });
343            };
344
345            quote! {
346                #destructure_tuple
347                #(#call_biplate_for_each_field);*
348                #build_child_tree
349                #build_child_ctx
350            }
351        }
352        ast::Type::BoxedTuple(tuple_type) => {
353            // destructure the tuple
354            let tuple_field_idents =
355                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}", member));
356            let destructure_tuple = quote! {
357                    let (#(#tuple_field_idents),*) = (*self.#member).clone();
358            };
359
360            // call biplate on each tuple field
361            let call_biplate_for_each_field = tuple_type.fields.iter().enumerate().map(|(i,_)| {
362                let field_ident = format_ident!("_{}_tuple_field_{i}",member);
363                let field_children_ident = format_ident!("_{}_tuple_field_{i}_children",member);
364                let field_ctx_ident = format_ident!("_{}_tuple_field_{i}_ctx", member);
365
366                // let index = syn::Index::from(i);
367                quote!{
368                    let (#field_children_ident,#field_ctx_ident) = ::uniplate::spez::try_biplate_to!(#field_ident.clone(), #to_t);
369                }
370            });
371
372            let tuple_field_children_idents =
373                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_children", member));
374            let tuple_field_ctx_idents =
375                (0..tuple_type.n).map(|i| format_ident!("_{}_tuple_field_{i}_ctx", member));
376
377            // build the children tree by combining each fields' tree
378            let build_child_tree = quote! {
379                let #children_ident = ::uniplate::Tree::Many(::std::collections::VecDeque::from([#(#tuple_field_children_idents),*]));
380            };
381
382            let index = (0..tuple_type.n).map(syn::Index::from);
383
384            // build the context function
385            let build_child_ctx = quote! {
386                let #ctx_ident = Box::new(move |x| {
387                    let ::uniplate::Tree::Many(xs) = x else {
388                        panic!()
389                    };
390
391                    (#(#tuple_field_ctx_idents(xs[#index].clone())),*)
392                });
393            };
394
395            quote! {
396                #destructure_tuple
397                #(#call_biplate_for_each_field);*
398                #build_child_tree
399                #build_child_ctx
400            }
401        }
402    }
403}
404
405fn _derive_children(_state: &mut ParserState, fields: &ast::Fields) -> TokenStream2 {
406    let mut subtrees: VecDeque<TokenStream2> = VecDeque::new();
407    for (member, _) in fields.defs() {
408        subtrees.push_back({
409            let children_ident = format_ident!("_{}_children", member);
410            quote!(#children_ident)
411        });
412    }
413
414    match subtrees.len() {
415        0 => quote! {let children = ::uniplate::Tree::Zero;},
416        _ => {
417            let subtrees = subtrees.iter();
418            quote! {let children = ::uniplate::Tree::Many(::std::collections::VecDeque::from([#(#subtrees),*]));}
419        }
420    }
421}
422
423fn _derive_ctx(
424    state: &mut ParserState,
425    fields: &ast::Fields,
426    var_ident: Option<&syn::Ident>,
427) -> TokenStream2 {
428    let field_ctxs: Vec<_> = fields
429        .defs()
430        .enumerate()
431        .map(|(i, (mem, typ))| match typ {
432            ast::Type::Basic(_) | ast::Type::Tuple(_) => {
433                let ctx_ident = format_ident!("_{}_ctx", mem);
434                quote! {#ctx_ident(x[#i].clone())}
435            }
436
437            ast::Type::BoxedBasic(_) | ast::Type::BoxedTuple(_) => {
438                let ctx_ident = format_ident!("_{}_ctx", mem);
439                quote! {Box::new(#ctx_ident(x[#i].clone()))}
440            }
441        })
442        .collect();
443
444    let data_ident = state.data.ident(); // The enum or struct name
445    let typ = state.to.clone();
446
447    // If this is an enum, use the passed variant identifier
448    let construct_ident = match var_ident {
449        Some(var) => quote! {#data_ident::#var},
450        None => quote! {#data_ident},
451    };
452    if fields.is_empty() {
453        quote! {
454            let ctx = Box::new(move |x: ::uniplate::Tree<#typ>| {
455                let ::uniplate::Tree::Zero = x else { panic!()};
456                #construct_ident
457            });
458        }
459    } else {
460        // If this is an enum, use the passed variant identifier
461        let construct_ident = match var_ident {
462            Some(var) => quote! {#data_ident::#var},
463            None => quote! {#data_ident},
464        };
465        let construct = match fields {
466            ast::Fields::Tuple(_) => {
467                quote! {
468                    #construct_ident(#(#field_ctxs),*)
469                }
470            }
471            ast::Fields::Struct(_) => {
472                let items = std::iter::zip(fields.idents(), field_ctxs.iter())
473                    .map(|(ident, ctx)| quote! {#ident: #ctx});
474                quote! {
475                    #construct_ident {
476                        #(#items),*
477                    }
478                }
479            }
480            ast::Fields::Unit => quote! {#var_ident},
481        };
482        quote! {
483            let ctx = Box::new(move |x: ::uniplate::Tree<#typ>| {
484                let ::uniplate::Tree::Many(x) = x else { panic!()};
485                #construct
486        });}
487    }
488}
489
490fn derive_a_biplate(state: &mut ParserState) -> TokenStream2 {
491    let from = state.from.to_token_stream();
492    let to = state.to.to_token_stream();
493
494    if from.to_string() == to.to_string() {
495        return _derive_identity_biplate(state, from);
496    }
497
498    let tokens: TokenStream2 = match state.data.clone() {
499        ast::Data::DataEnum(x) => _derive_a_enum_uniplate(state, x),
500        ast::Data::DataStruct(x) => _derive_a_struct_uniplate(state, x),
501    };
502
503    let mut generics = state.data.generics().clone();
504    for (typ, bounds) in generics.type_parameters.iter_mut() {
505        // Add 'static bounds to all generic type parameters.
506        bounds.push(syn::TypeParamBound::Verbatim(quote!('static)));
507
508        // If we are deriving Biplate<T>, T must be Uniplate
509        if to.to_string() == typ.to_token_stream().to_string() {
510            bounds.push(syn::TypeParamBound::Verbatim(quote!(Uniplate)));
511        }
512    }
513
514    let impl_bounds = generics.impl_parameters();
515    let where_clause = generics.impl_type_where_block();
516
517    quote! {
518        impl<#impl_bounds> ::uniplate::Biplate<#to> for #from #where_clause{
519            fn biplate(&self) -> (::uniplate::Tree<#to>, Box<dyn Fn(::uniplate::Tree<#to>) -> #from>) {
520                #tokens
521            }
522        }
523    }
524}
525
526fn _derive_identity_biplate(state: &mut ParserState, from: TokenStream2) -> TokenStream2 {
527    let mut generics = state.data.generics().clone();
528    // Add 'static bounds to all generic type parameters.
529    for (_, bounds) in generics.type_parameters.iter_mut() {
530        bounds.push(syn::TypeParamBound::Verbatim(quote!('static)));
531    }
532
533    let impl_bounds = generics.impl_parameters();
534    let where_clause = generics.impl_type_where_block();
535
536    quote! {
537        impl<#impl_bounds> ::uniplate::Biplate<#from> for #from #where_clause{
538            fn biplate(&self) -> (::uniplate::Tree<#from>, Box<dyn Fn(::uniplate::Tree<#from>) -> #from>) {
539                let val = self.clone();
540                (::uniplate::Tree::One(val.clone()),Box::new(move |x| {
541                    let ::uniplate::Tree::One(x) = x else {todo!()};
542                    x
543                }))
544            }
545        }
546    }
547}