1use proc_macro2::TokenStream;
19use quote::{format_ident, quote, quote_spanned, ToTokens};
20use syn::spanned::Spanned;
21use syn::{
22 parse::{Parse, ParseStream},
23 parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
24 Ident, Index, LitStr, Meta, Token, Type, TypePath,
25};
26use syn::{Path, PathArguments};
27
28#[proc_macro_derive(VisitMut, attributes(visit))]
30pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 derive_visit(
32 input,
33 &VisitType {
34 visit_trait: quote!(VisitMut),
35 visitor_trait: quote!(VisitorMut),
36 modifier: Some(quote!(mut)),
37 },
38 )
39}
40
41#[proc_macro_derive(Visit, attributes(visit))]
43pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
44 derive_visit(
45 input,
46 &VisitType {
47 visit_trait: quote!(Visit),
48 visitor_trait: quote!(Visitor),
49 modifier: None,
50 },
51 )
52}
53
54struct VisitType {
55 visit_trait: TokenStream,
56 visitor_trait: TokenStream,
57 modifier: Option<TokenStream>,
58}
59
60fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
61 let input = parse_macro_input!(input as DeriveInput);
63 let name = input.ident;
64
65 let VisitType {
66 visit_trait,
67 visitor_trait,
68 modifier,
69 } = visit_type;
70
71 let attributes = Attributes::parse(&input.attrs);
72 let generics = add_trait_bounds(input.generics, visit_type);
74 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
75
76 let (pre_visit, post_visit) = attributes.visit(quote!(self));
77 let children = visit_children(&input.data, visit_type);
78
79 let expanded = quote! {
80 impl #impl_generics sqltk_parser::ast::#visit_trait for #name #ty_generics #where_clause {
82 fn visit<V: sqltk_parser::ast::#visitor_trait>(
83 &#modifier self,
84 visitor: &mut V
85 ) -> ::std::ops::ControlFlow<V::Break> {
86 #pre_visit
87 #children
88 #post_visit
89 ::std::ops::ControlFlow::Continue(())
90 }
91 }
92 };
93
94 proc_macro::TokenStream::from(expanded)
95}
96
97#[derive(Default)]
101struct Attributes {
102 with: Option<Ident>,
104}
105
106struct WithIdent {
107 with: Option<Ident>,
108}
109impl Parse for WithIdent {
110 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
111 let mut result = WithIdent { with: None };
112 let ident = input.parse::<Ident>()?;
113 if ident != "with" {
114 return Err(syn::Error::new(
115 ident.span(),
116 "Expected identifier to be `with`",
117 ));
118 }
119 input.parse::<Token!(=)>()?;
120 let s = input.parse::<LitStr>()?;
121 result.with = Some(format_ident!("{}", s.value(), span = s.span()));
122 Ok(result)
123 }
124}
125
126impl Attributes {
127 fn parse(attrs: &[Attribute]) -> Self {
128 let mut out = Self::default();
129 for attr in attrs {
130 if let Meta::List(ref metalist) = attr.meta {
131 if metalist.path.is_ident("visit") {
132 match syn::parse2::<WithIdent>(metalist.tokens.clone()) {
133 Ok(with_ident) => {
134 out.with = with_ident.with;
135 }
136 Err(e) => {
137 panic!("{}", e);
138 }
139 }
140 }
141 }
142 }
143 out
144 }
145
146 fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
148 let pre_visit = self.with.as_ref().map(|m| {
149 let m = format_ident!("pre_{}", m);
150 quote!(visitor.#m(#s)?;)
151 });
152 let post_visit = self.with.as_ref().map(|m| {
153 let m = format_ident!("post_{}", m);
154 quote!(visitor.#m(#s)?;)
155 });
156 (pre_visit, post_visit)
157 }
158}
159
160fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics {
162 for param in &mut generics.params {
163 if let GenericParam::Type(ref mut type_param) = *param {
164 type_param
165 .bounds
166 .push(parse_quote!(sqltk_parser::ast::#visit_trait));
167 }
168 }
169 generics
170}
171
172fn visit_children(
174 data: &Data,
175 VisitType {
176 visit_trait,
177 modifier,
178 ..
179 }: &VisitType,
180) -> TokenStream {
181 match data {
182 Data::Struct(data) => match &data.fields {
183 Fields::Named(fields) => {
184 let recurse = fields.named.iter().map(|f| {
185 let name = &f.ident;
186 let is_option = is_option(&f.ty);
187 let attributes = Attributes::parse(&f.attrs);
188 if is_option && attributes.with.is_some() {
189 let (pre_visit, post_visit) = attributes.visit(quote!(value));
190 quote_spanned!(f.span() =>
191 if let Some(value) = &#modifier self.#name {
192 #pre_visit sqltk_parser::ast::#visit_trait::visit(value, visitor)?; #post_visit
193 }
194 )
195 } else {
196 let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
197 quote_spanned!(f.span() =>
198 #pre_visit sqltk_parser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
199 )
200 }
201 });
202 quote! {
203 #(#recurse)*
204 }
205 }
206 Fields::Unnamed(fields) => {
207 let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
208 let index = Index::from(i);
209 let attributes = Attributes::parse(&f.attrs);
210 let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
211 quote_spanned!(f.span() => #pre_visit sqltk_parser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
212 });
213 quote! {
214 #(#recurse)*
215 }
216 }
217 Fields::Unit => {
218 quote!()
219 }
220 },
221 Data::Enum(data) => {
222 let statements = data.variants.iter().map(|v| {
223 let name = &v.ident;
224 match &v.fields {
225 Fields::Named(fields) => {
226 let names = fields.named.iter().map(|f| &f.ident);
227 let visit = fields.named.iter().map(|f| {
228 let name = &f.ident;
229 let attributes = Attributes::parse(&f.attrs);
230 let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
231 quote_spanned!(f.span() => #pre_visit sqltk_parser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
232 });
233
234 quote!(
235 Self::#name { #(#names),* } => {
236 #(#visit)*
237 }
238 )
239 }
240 Fields::Unnamed(fields) => {
241 let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
242 let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
243 let name = format_ident!("_{}", i);
244 let attributes = Attributes::parse(&f.attrs);
245 let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
246 quote_spanned!(f.span() => #pre_visit sqltk_parser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
247 });
248
249 quote! {
250 Self::#name ( #(#names),*) => {
251 #(#visit)*
252 }
253 }
254 }
255 Fields::Unit => {
256 quote! {
257 Self::#name => {}
258 }
259 }
260 }
261 });
262
263 quote! {
264 match self {
265 #(#statements),*
266 }
267 }
268 }
269 Data::Union(_) => unimplemented!(),
270 }
271}
272
273fn is_option(ty: &Type) -> bool {
274 if let Type::Path(TypePath {
275 path: Path { segments, .. },
276 ..
277 }) = ty
278 {
279 if let Some(segment) = segments.last() {
280 if segment.ident == "Option" {
281 if let PathArguments::AngleBracketed(args) = &segment.arguments {
282 return args.args.len() == 1;
283 }
284 }
285 }
286 }
287 false
288}