1use proc_macro2::{Group, TokenStream, TokenTree};
2use quote::{TokenStreamExt, quote};
3use syn::{
4 Attribute, Block, FnArg, Generics, Ident, Pat, Stmt, Token, TraitItem, Visibility, braced,
5 parse::Parse,
6};
7
8pub struct Input {
9 attrs: Vec<Attribute>,
10 vis: Visibility,
11 trait_token: Token![trait],
12 name: Ident,
13 mut_name: Option<Ident>,
14 generics: Generics,
15 items: TokenStream,
16}
17
18impl Parse for Input {
19 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
20 let attrs = input.call(Attribute::parse_outer)?;
21 let vis = input.parse()?;
22 let trait_token = input.parse()?;
23 let name: Ident = input.parse()?;
24 let mut_name: Option<Ident> = input.parse()?;
25 let generics: Generics = input.parse()?;
26
27 let content;
28 braced!(content in input);
29 let items = content.parse()?;
30
31 Ok(Self { attrs, vis, trait_token, name, mut_name, generics, items })
32 }
33}
34
35impl Input {
36 pub fn expand(&self) -> TokenStream {
37 let Self { attrs, vis, trait_token, name, mut_name, generics, items } = self;
38
39 let expand = |nonmut_items: TokenStream, mut_items: Option<TokenStream>| {
40 let mut_trait = mut_items.map(|mut_items| {
41 quote! {
42 #(#attrs)*
43 #vis #trait_token #mut_name #generics {
44 #mut_items
45 }
46 }
47 });
48 quote! {
49 #(#attrs)*
50 #vis #trait_token #name #generics {
51 #nonmut_items
52 }
53
54 #mut_trait
55 }
56 };
57
58 let (nonmut_items, mut_items) = expand_streams(items);
59 let fallback = || expand(nonmut_items.clone(), None);
61 let Ok(mut nonmut_trait_items) = parse_trait_items(nonmut_items.clone()) else {
62 return fallback();
63 };
64 let Ok(mut mut_trait_items) = parse_trait_items(mut_items) else {
65 return fallback();
66 };
67
68 for item in &mut mut_trait_items {
69 if let TraitItem::Fn(f) = item {
70 f.sig.ident = Ident::new(&format!("{}_mut", f.sig.ident), f.sig.ident.span());
71 }
72 }
73
74 add_walk_fns(&mut mut_trait_items);
75 add_walk_fns(&mut nonmut_trait_items);
76
77 expand(
78 quote! { #(#nonmut_trait_items)* },
79 mut_name.is_some().then(|| quote! { #(#mut_trait_items)* }),
80 )
81 }
82}
83
84fn expand_streams(tts: &TokenStream) -> (TokenStream, TokenStream) {
87 let mut nonmut_tts = TokenStream::new();
88 let mut mut_tts = TokenStream::new();
89 let mut tt_iter = tts.clone().into_iter();
90 while let Some(tt) = tt_iter.next() {
91 match tt {
92 TokenTree::Group(group) => {
93 let (nm, m) = expand_streams(&group.stream());
94 let group = |stream| {
95 let mut g = Group::new(group.delimiter(), stream);
96 g.set_span(group.span());
97 g
98 };
99 nonmut_tts.append(group(nm));
100 mut_tts.append(group(m));
101 }
102 TokenTree::Punct(punct)
103 if punct.as_char() == '#' && tt_iter.clone().next().is_some_and(is_token_mut) =>
104 {
105 let mut_token = tt_iter.next().unwrap();
106 mut_tts.append(mut_token);
107 }
108 TokenTree::Punct(punct)
109 if punct.as_char() == '#'
110 && tt_iter.clone().next().is_some_and(is_token_onlymut) =>
111 {
112 let _onlymut_token = tt_iter.next().unwrap();
113 let TokenTree::Group(group) = tt_iter.next().unwrap() else { continue };
114 mut_tts.extend(group.stream());
115 }
116 TokenTree::Ident(id)
117 if tt_iter.clone().next().is_some_and(is_token_hash)
118 && tt_iter.clone().nth(1).is_some_and(is_token_underscore_mut) =>
119 {
120 let _ = tt_iter.next();
121 let _ = tt_iter.next();
122 mut_tts.append(Ident::new(&format!("{id}_mut"), id.span()));
123 nonmut_tts.append(id);
124 }
125 tt => {
126 nonmut_tts.append(tt.clone());
127 mut_tts.append(tt);
128 }
129 }
130 }
131 (nonmut_tts, mut_tts)
132}
133
134fn is_token_hash(tt: TokenTree) -> bool {
135 if let TokenTree::Punct(punct) = tt {
136 return punct.as_char() == '#';
137 }
138 false
139}
140
141fn is_token_mut(tt: TokenTree) -> bool {
142 if let TokenTree::Ident(ident) = tt {
143 return ident == "mut";
144 }
145 false
146}
147
148fn is_token_onlymut(tt: TokenTree) -> bool {
149 if let TokenTree::Ident(ident) = tt {
150 return ident == "onlymut";
151 }
152 false
153}
154
155fn is_token_underscore_mut(tt: TokenTree) -> bool {
156 if let TokenTree::Ident(ident) = tt {
157 return ident == "_mut";
158 }
159 false
160}
161
162fn parse_trait_items(tts: TokenStream) -> Result<Vec<TraitItem>, syn::Error> {
163 struct TraitItems(Vec<TraitItem>);
164 impl Parse for TraitItems {
165 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
166 let mut items = vec![];
167 while !input.is_empty() {
168 items.push(input.parse()?);
169 }
170 Ok(Self(items))
171 }
172 }
173 Ok(syn::parse2::<TraitItems>(tts)?.0)
174}
175
176fn add_walk_fns(items: &mut Vec<TraitItem>) {
179 for i in 0..items.len() {
180 let item = &mut items[i];
181 if let TraitItem::Fn(f) = item {
182 let name = f.sig.ident.to_string();
183 let Some(name) = name.strip_prefix("visit_") else { continue };
184 let walk_name = Ident::new(&format!("walk_{name}"), f.sig.ident.span());
185
186 let mut walk_fn = f.clone();
187 let Some(body) = &mut f.default else { continue };
188 f.attrs.push(syn::parse_quote!(#[inline]));
189
190 let args = f.sig.inputs.iter().filter_map(|arg| {
191 Some(match arg {
192 FnArg::Receiver(_rec) => return None,
193 FnArg::Typed(pat) => match &*pat.pat {
194 Pat::Ident(ident) => {
195 let id = &ident.ident;
196 quote!(#id)
197 }
198 _ => return None,
199 },
200 })
201 });
202 let call_walk = syn::parse_quote! {
203 self.#walk_name(#(#args),*)
204 };
205 let call_walk_stmt = Stmt::Expr(call_walk, None);
206 let walk_stmts = std::mem::replace(&mut body.stmts, vec![call_walk_stmt]);
207
208 walk_fn.sig.ident = walk_name;
209 walk_fn.default = Some(Block { brace_token: body.brace_token, stmts: walk_stmts });
210 items.push(TraitItem::Fn(walk_fn));
211 }
212 }
213}