1use std::collections::HashSet;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{
6 parse_quote, DataStruct, DeriveInput, Fields, Ident, Lit, Meta, Path, WhereClause,
7 WherePredicate,
8};
9
10fn get_rename_attribute(ast: &DeriveInput) -> Option<String> {
11 for attr in &ast.attrs {
12 if attr.path().is_ident("visit") {
14 if let Ok(meta_list) = attr.meta.require_list() {
15 if let Ok(Meta::NameValue(nv)) = syn::parse2::<Meta>(meta_list.tokens.clone()) {
16 if nv.path.is_ident("rename") {
17 if let syn::Expr::Lit(lit) = &nv.value {
18 if let Lit::Str(s) = &lit.lit {
19 return Some(s.value());
20 }
21 }
22 }
23 }
24 }
25 }
26 if attr.path().is_ident("serde") {
28 if let Ok(meta_list) = attr.meta.require_list() {
29 if let Ok(Meta::NameValue(nv)) = syn::parse2::<Meta>(meta_list.tokens.clone()) {
30 if nv.path.is_ident("rename") {
31 if let syn::Expr::Lit(lit) = &nv.value {
32 if let Lit::Str(s) = &lit.lit {
33 return Some(s.value());
34 }
35 }
36 }
37 }
38 }
39 }
40 }
41 None
42}
43
44fn make_impl(
45 input: &DeriveInput,
46 fields: &Fields,
47 trait_path_fields: &Path,
48 trait_path: &Path,
49 named: Option<&Path>,
50 sync: bool,
51 is_static: bool,
52) -> TokenStream {
53 let ident = &input.ident;
54
55 let (_, ty_generics, _) = &input.generics.split_for_impl();
56
57 let mut generics = input.generics.clone();
58
59 generics.params.push(syn::parse_quote! { __visit_rs__V });
60
61 let predicates = &mut generics
62 .where_clause
63 .get_or_insert(WhereClause {
64 predicates: Default::default(),
65 where_token: Default::default(),
66 })
67 .predicates;
68
69 predicates.push(syn::parse_quote! { __visit_rs__V: visit_rs::Visitor });
70 if sync {
71 predicates.extend(fields.iter().map(|f| &f.ty).map(|t| -> WherePredicate {
72 parse_quote! { #t: Sync }
73 }));
74 }
75
76 let mut ty_set = HashSet::new();
77 for (_, field) in field_iter(fields) {
78 let ty = &field.ty;
79 if !ty_set.insert(ty) {
80 continue;
81 }
82 if let Some(named) = named {
83 if is_static {
84 predicates.push(
85 syn::parse_quote! { for<'__visit_rs__named> #named <'__visit_rs__named, visit_rs::Static<#ty>>: #trait_path<__visit_rs__V> },
86 );
87 } else {
88 predicates.push(syn::parse_quote! { for<'__visit_rs__named> #named <'__visit_rs__named, #ty>: #trait_path<__visit_rs__V> });
89 }
90 } else {
91 if is_static {
92 predicates
93 .push(syn::parse_quote! { visit_rs::Static<#ty>: #trait_path<__visit_rs__V> });
94 } else {
95 predicates.push(syn::parse_quote! { #ty: #trait_path<__visit_rs__V> });
96 }
97 }
98 }
99
100 let (impl_generics, _, where_clause) = generics.split_for_impl();
101
102 quote! {
103 impl #impl_generics #trait_path_fields<__visit_rs__V> for #ident #ty_generics #where_clause
104 }
105}
106
107fn field_iter(fields: &Fields) -> impl Iterator<Item = (usize, &syn::Field)> {
108 fields.iter().enumerate().filter(|(_, field)| {
109 !field.attrs.iter().any(|attr| {
110 attr.path().is_ident("visit")
111 && attr.parse_args::<Ident>().map_or(false, |id| id == "skip")
112 })
113 })
114}
115
116fn field_idx_iter(fields: &Fields) -> impl Iterator<Item = TokenStream> {
117 field_iter(fields).map(|(index, field)| {
118 let field_name = &field.ident;
119 if let Some(name) = field_name {
120 quote! { #name }
121 } else {
122 let index = syn::Index::from(index);
123 quote! { #index }
124 }
125 })
126}
127
128fn field_name_idx_iter(fields: &syn::Fields) -> impl Iterator<Item = (TokenStream, TokenStream)> {
129 field_iter(fields).map(|(index, field)| {
130 let field_name = &field.ident;
131 let idx = if let Some(name) = field_name {
132 quote! { #name }
133 } else {
134 let index = syn::Index::from(index);
135 quote! { #index }
136 };
137 let name = if let Some(name) = field_name {
138 quote! { Some(stringify!(#name)) }
139 } else {
140 quote! { None }
141 };
142 (name, idx)
143 })
144}
145
146#[proc_macro_derive(VisitFields, attributes(visit))]
147pub fn derive_visit_fields_(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
148 let ast: DeriveInput = syn::parse(input).unwrap();
149
150 let syn::Data::Struct(data) = &ast.data else {
151 let span = match &ast.data {
152 syn::Data::Enum(data) => data.enum_token.span,
153 syn::Data::Union(data) => data.union_token.span,
154 _ => Span::call_site(),
155 };
156 return syn::Error::new(span, "VisitFields can only be derived for structs")
157 .to_compile_error()
158 .into();
159 };
160
161 let all_impls = match (|| {
162 Ok::<_, syn::Error>([
163 derive_struct_info(&ast, data)?,
164 derive_visit_fields(&ast, data)?,
165 derive_visit_fields_covered(&ast, data)?,
166 derive_visit_fields_async(&ast, data)?,
167 derive_visit_fields_covered_async(&ast, data)?,
168 derive_visit_fields_named(&ast, data)?,
169 derive_visit_fields_named_async(&ast, data)?,
170 derive_visit_fields_static(&ast, data)?,
171 derive_visit_fields_static_async(&ast, data)?,
172 derive_visit_fields_static_named(&ast, data)?,
173 derive_visit_fields_static_named_async(&ast, data)?,
174 ])
175 })() {
176 Ok(a) => a,
177 Err(e) => return e.to_compile_error().into(),
178 };
179
180 proc_macro::TokenStream::from(quote! {
183 #(#all_impls)*
184 })
185 }
187
188fn derive_struct_info(ast: &DeriveInput, data: &DataStruct) -> Result<TokenStream, syn::Error> {
189 let ident = &ast.ident;
190 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
191
192 let named_fields = matches!(data.fields, Fields::Named(_));
193 let field_count = field_iter(&data.fields).count();
194
195 let name = get_rename_attribute(ast).unwrap_or_else(|| ident.to_string());
196
197 Ok(quote! {
198 impl #impl_generics visit_rs::StructInfo for #ident #ty_generics #where_clause {
199 const DATA: visit_rs::StructInfoData = visit_rs::StructInfoData {
200 name: #name,
201 named_fields: #named_fields,
202 field_count: #field_count,
203 };
204 }
205 })
206}
207
208fn derive_visit_fields(ast: &DeriveInput, data: &DataStruct) -> Result<TokenStream, syn::Error> {
209 let impl_t = make_impl(
210 &ast,
211 &data.fields,
212 &syn::parse_quote! { visit_rs::VisitFields },
213 &syn::parse_quote! { visit_rs::Visit },
214 None,
215 false,
216 false,
217 );
218
219 let visit_fields_impl = field_idx_iter(&data.fields).enumerate().map(|(num, idx)| {
220 quote! {
221 #num => {
222 pos += 1;
223 Some(visit_rs::Visit::visit(&self.#idx, visitor))
224 }
225 }
226 });
227
228 Ok(quote! {
229 #impl_t {
230 fn visit_fields<'__visit_rs__a>(
231 &'__visit_rs__a self,
232 visitor: &'__visit_rs__a mut __visit_rs__V
233 ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> {
234 std::iter::from_fn({
235 let mut pos = 0;
236 move || match pos {
237 #(#visit_fields_impl)*
238 _ => None,
239 }
240 })
241 }
242 }
243 })
244}
245
246fn derive_visit_fields_covered(
247 ast: &DeriveInput,
248 data: &DataStruct,
249) -> Result<TokenStream, syn::Error> {
250 let impl_t = make_impl(
251 &ast,
252 &data.fields,
253 &syn::parse_quote! { visit_rs::VisitFieldsCovered },
254 &syn::parse_quote! { visit_rs::Visit },
255 Some(&syn::parse_quote! { visit_rs::Covered }),
256 false,
257 false,
258 );
259
260 let visit_fields_impl = field_idx_iter(&data.fields).enumerate().map(|(num, idx)| {
261 quote! {
262 #num => {
263 pos += 1;
264 Some(visit_rs::Visit::visit(&visit_rs::Covered(&self.#idx), visitor))
265 }
266 }
267 });
268
269 Ok(quote! {
270 #impl_t {
271 fn visit_fields_covered<'__visit_rs__a>(
272 &'__visit_rs__a self,
273 visitor: &'__visit_rs__a mut __visit_rs__V
274 ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> {
275 std::iter::from_fn({
276 let mut pos = 0;
277 move || match pos {
278 #(#visit_fields_impl)*
279 _ => None,
280 }
281 })
282 }
283 }
284 })
285}
286
287fn derive_visit_fields_async(
288 ast: &DeriveInput,
289 data: &DataStruct,
290) -> Result<TokenStream, syn::Error> {
291 let impl_t = make_impl(
292 &ast,
293 &data.fields,
294 &syn::parse_quote! { visit_rs::VisitFieldsAsync },
295 &syn::parse_quote! { visit_rs::VisitAsync },
296 None,
297 true,
298 false,
299 );
300
301 let visit_fields_impl = field_idx_iter(&data.fields).map(|idx| {
302 quote! {
303 yield visit_rs::VisitAsync::visit_async(&self.#idx, visitor).await;
304 }
305 });
306
307 Ok(quote! {
308 #impl_t {
309 fn visit_fields_async<'__visit_rs__a>(
310 &'__visit_rs__a self,
311 visitor: &'__visit_rs__a mut __visit_rs__V,
312 ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
313 where
314 __visit_rs__V: Send,
315 <__visit_rs__V as visit_rs::Visitor>::Result: Send,
316 {
317 visit_rs::lib::async_stream::stream! {
318 #(#visit_fields_impl)*
319 #[allow(unreachable_code)]
320 if false {
321 yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
322 }
323 }
324 }
325 }
326 })
327}
328
329fn derive_visit_fields_covered_async(
330 ast: &DeriveInput,
331 data: &DataStruct,
332) -> Result<TokenStream, syn::Error> {
333 let impl_t = make_impl(
334 &ast,
335 &data.fields,
336 &syn::parse_quote! { visit_rs::VisitFieldsCoveredAsync },
337 &syn::parse_quote! { visit_rs::VisitAsync },
338 Some(&syn::parse_quote! { visit_rs::Covered }),
339 true,
340 false,
341 );
342
343 let visit_fields_impl = field_idx_iter(&data.fields).map(|idx| {
344 quote! {
345 yield visit_rs::VisitAsync::visit_async(&visit_rs::Covered(&self.#idx), visitor).await;
346 }
347 });
348
349 Ok(quote! {
350 #impl_t {
351 fn visit_fields_covered_async<'__visit_rs__a>(
352 &'__visit_rs__a self,
353 visitor: &'__visit_rs__a mut __visit_rs__V,
354 ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
355 where
356 __visit_rs__V: Send,
357 <__visit_rs__V as visit_rs::Visitor>::Result: Send,
358 {
359 visit_rs::lib::async_stream::stream! {
360 #(#visit_fields_impl)*
361 #[allow(unreachable_code)]
362 if false {
363 yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
364 }
365 }
366 }
367 }
368 })
369}
370
371fn derive_visit_fields_named(
372 ast: &DeriveInput,
373 data: &DataStruct,
374) -> Result<TokenStream, syn::Error> {
375 let impl_t = make_impl(
376 &ast,
377 &data.fields,
378 &syn::parse_quote! { visit_rs::VisitFieldsNamed },
379 &syn::parse_quote! { visit_rs::Visit },
380 Some(&syn::parse_quote! { visit_rs::Named }),
381 false,
382 false,
383 );
384
385 let visit_fields_named_impl =
386 field_name_idx_iter(&data.fields)
387 .enumerate()
388 .map(|(num, (name, idx))| {
389 quote! {
390 #num => {
391 pos += 1;
392 Some(visit_rs::Visit::visit(&visit_rs::Named {
393 name: #name,
394 value: &self.#idx,
395 }, visitor))
396 }
397 }
398 });
399
400 Ok(quote! {
401 #impl_t {
402 fn visit_fields_named<'__visit_rs__a>(
403 &'__visit_rs__a self,
404 visitor: &'__visit_rs__a mut __visit_rs__V
405 ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
406 std::iter::from_fn({
407 let mut pos = 0;
408 move || match pos {
409 #(#visit_fields_named_impl)*
410 _ => None,
411 }
412 })
413 }
414 }
415 })
416}
417
418fn derive_visit_fields_named_async(
419 ast: &DeriveInput,
420 data: &DataStruct,
421) -> Result<TokenStream, syn::Error> {
422 let impl_t = make_impl(
423 &ast,
424 &data.fields,
425 &syn::parse_quote! { visit_rs::VisitFieldsNamedAsync },
426 &syn::parse_quote! { visit_rs::VisitAsync },
427 Some(&syn::parse_quote! { visit_rs::Named }),
428 true,
429 false,
430 );
431
432 let visit_fields_named_impl = field_name_idx_iter(&data.fields).map(|(name, idx)| {
433 quote! {
434 yield visit_rs::VisitAsync::visit_async(&visit_rs::Named {
435 name: #name,
436 value: &self.#idx,
437 }, visitor).await;
438 }
439 });
440
441 Ok(quote! {
442 #impl_t {
443 fn visit_fields_named_async<'__visit_rs__a>(
444 &'__visit_rs__a self,
445 visitor: &'__visit_rs__a mut __visit_rs__V,
446 ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
447 where
448 __visit_rs__V: Send,
449 <__visit_rs__V as visit_rs::Visitor>::Result: Send,
450 {
451 visit_rs::lib::async_stream::stream! {
452 #(#visit_fields_named_impl)*
453 #[allow(unreachable_code)]
454 if false {
455 yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
456 }
457 }
458 }
459 }
460 })
461}
462
463fn derive_visit_fields_static(
464 ast: &DeriveInput,
465 data: &DataStruct,
466) -> Result<TokenStream, syn::Error> {
467 let impl_t = make_impl(
468 &ast,
469 &data.fields,
470 &syn::parse_quote! { visit_rs::VisitFieldsStatic },
471 &syn::parse_quote! { visit_rs::Visit },
472 None,
473 false,
474 true,
475 );
476
477 let field_types: Vec<_> = field_iter(&data.fields)
478 .map(|(_, field)| &field.ty)
479 .collect();
480 let visit_fields_impl = field_types.iter().enumerate().map(|(num, ty)| {
481 quote! {
482 #num => {
483 pos += 1;
484 Some(visit_rs::Visit::visit(&visit_rs::Static::<#ty>::new(), visitor))
485 }
486 }
487 });
488
489 Ok(quote! {
490 #impl_t {
491 fn visit_fields_static<'__visit_rs__a>(
492 visitor: &'__visit_rs__a mut __visit_rs__V
493 ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
494 std::iter::from_fn({
495 let mut pos = 0;
496 move || match pos {
497 #(#visit_fields_impl)*
498 _ => None,
499 }
500 })
501 }
502 }
503 })
504}
505
506fn derive_visit_fields_static_async(
507 ast: &DeriveInput,
508 data: &DataStruct,
509) -> Result<TokenStream, syn::Error> {
510 let impl_t = make_impl(
511 &ast,
512 &data.fields,
513 &syn::parse_quote! { visit_rs::VisitFieldsStaticAsync },
514 &syn::parse_quote! { visit_rs::VisitAsync },
515 None,
516 true,
517 true,
518 );
519
520 let field_types: Vec<_> = field_iter(&data.fields)
521 .map(|(_, field)| &field.ty)
522 .collect();
523 let visit_fields_impl = field_types.iter().map(|ty| {
524 quote! {
525 yield visit_rs::VisitAsync::visit_async(&visit_rs::Static::<#ty>::new(), visitor).await;
526 }
527 });
528
529 Ok(quote! {
530 #impl_t {
531 fn visit_fields_static_async<'__visit_rs__a>(
532 visitor: &'__visit_rs__a mut __visit_rs__V,
533 ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
534 where
535 __visit_rs__V: Send,
536 <__visit_rs__V as visit_rs::Visitor>::Result: Send,
537 {
538 visit_rs::lib::async_stream::stream! {
539 #(#visit_fields_impl)*
540 #[allow(unreachable_code)]
541 if false {
542 yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
543 }
544 }
545 }
546 }
547 })
548}
549
550fn derive_visit_fields_static_named(
551 ast: &DeriveInput,
552 data: &DataStruct,
553) -> Result<TokenStream, syn::Error> {
554 let impl_t = make_impl(
555 &ast,
556 &data.fields,
557 &syn::parse_quote! { visit_rs::VisitFieldsStaticNamed },
558 &syn::parse_quote! { visit_rs::Visit },
559 Some(&syn::parse_quote! { visit_rs::Named }),
560 false,
561 true,
562 );
563
564 let field_name_type_iter = field_iter(&data.fields).map(|(_, field)| {
565 let field_name = &field.ident;
566 let ty = &field.ty;
567 let name = if let Some(name) = field_name {
568 quote! { Some(stringify!(#name)) }
569 } else {
570 quote! { None }
571 };
572 (name, ty)
573 });
574
575 let visit_fields_named_impl = field_name_type_iter.enumerate().map(|(num, (name, ty))| {
576 quote! {
577 #num => {
578 pos += 1;
579 {
580 static __VISIT_RS_STATIC: visit_rs::Static<()> = visit_rs::Static::new();
581 let named = visit_rs::Named {
582 name: #name,
583 value: unsafe {
584 &*(&__VISIT_RS_STATIC as *const visit_rs::Static<()> as *const visit_rs::Static<#ty>)
587 },
588 };
589 Some(visit_rs::Visit::visit(&named, visitor))
590 }
591 }
592 }
593 });
594
595 Ok(quote! {
596 #impl_t {
597 fn visit_fields_static_named<'__visit_rs__a>(
598 visitor: &'__visit_rs__a mut __visit_rs__V
599 ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
600 std::iter::from_fn({
601 let mut pos = 0;
602 move || match pos {
603 #(#visit_fields_named_impl)*
604 _ => None,
605 }
606 })
607 }
608 }
609 })
610}
611
612fn derive_visit_fields_static_named_async(
613 ast: &DeriveInput,
614 data: &DataStruct,
615) -> Result<TokenStream, syn::Error> {
616 let impl_t = make_impl(
617 &ast,
618 &data.fields,
619 &syn::parse_quote! { visit_rs::VisitFieldsStaticNamedAsync },
620 &syn::parse_quote! { visit_rs::VisitAsync },
621 Some(&syn::parse_quote! { visit_rs::Named }),
622 true,
623 true,
624 );
625
626 let field_name_type_iter = field_iter(&data.fields).map(|(_, field)| {
627 let field_name = &field.ident;
628 let ty = &field.ty;
629 let name = if let Some(name) = field_name {
630 quote! { Some(stringify!(#name)) }
631 } else {
632 quote! { None }
633 };
634 (name, ty)
635 });
636
637 let visit_fields_named_impl = field_name_type_iter.map(|(name, ty)| {
638 quote! {
639 {
640 static __VISIT_RS_STATIC: visit_rs::Static<()> = visit_rs::Static::new();
641 let named = visit_rs::Named {
642 name: #name,
643 value: unsafe {
644 &*(&__VISIT_RS_STATIC as *const visit_rs::Static<()> as *const visit_rs::Static<#ty>)
647 },
648 };
649 yield visit_rs::VisitAsync::visit_async(&named, visitor).await;
650 }
651 }
652 });
653
654 Ok(quote! {
655 #impl_t {
656 fn visit_fields_static_named_async<'__visit_rs__a>(
657 visitor: &'__visit_rs__a mut __visit_rs__V,
658 ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
659 where
660 __visit_rs__V: Send,
661 <__visit_rs__V as visit_rs::Visitor>::Result: Send,
662 {
663 visit_rs::lib::async_stream::stream! {
664 #(#visit_fields_named_impl)*
665 #[allow(unreachable_code)]
666 if false {
667 yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
668 }
669 }
670 }
671 }
672 })
673}
674
675mod helpers;
676mod enum_variants;
677
678#[proc_macro_derive(VisitVariants, attributes(visit))]
679pub fn derive_visit_variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
680 let ast: DeriveInput = syn::parse(input).unwrap();
681
682 let syn::Data::Enum(data) = &ast.data else {
683 return syn::Error::new_spanned(&ast.ident, "VisitVariants can only be used on enums")
684 .to_compile_error()
685 .into();
686 };
687
688 match enum_variants::derive_all_variant_traits(&ast, data) {
689 Ok(tokens) => tokens.into(),
690 Err(e) => e.to_compile_error().into(),
691 }
692}