yachtsql_sqlparser_derive/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use proc_macro2::TokenStream;
19use quote::{format_ident, quote, quote_spanned, ToTokens};
20use syn::spanned::Spanned;
21use syn::{
22    parse::{Parse, ParseStream},
23    parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
24    Ident, Index, LitStr, Meta, Token, Type, TypePath,
25};
26use syn::{Path, PathArguments};
27
28/// Implementation of `[#derive(Visit)]`
29#[proc_macro_derive(VisitMut, attributes(visit))]
30pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31    derive_visit(
32        input,
33        &VisitType {
34            visit_trait: quote!(VisitMut),
35            visitor_trait: quote!(VisitorMut),
36            modifier: Some(quote!(mut)),
37        },
38    )
39}
40
41/// Implementation of `[#derive(Visit)]`
42#[proc_macro_derive(Visit, attributes(visit))]
43pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
44    derive_visit(
45        input,
46        &VisitType {
47            visit_trait: quote!(Visit),
48            visitor_trait: quote!(Visitor),
49            modifier: None,
50        },
51    )
52}
53
/// Selects which flavour of impl `derive_visit` generates (shared vs. mutable).
struct VisitType {
    /// Name of the trait implemented for the deriving type (`Visit` or `VisitMut`).
    visit_trait: TokenStream,
    /// Name of the visitor trait that bounds the generic `V` parameter of the
    /// generated `visit` method (`Visitor` or `VisitorMut`).
    visitor_trait: TokenStream,
    /// Mutability token spliced after `&` in generated references:
    /// `Some(mut)` for the mutable derive, `None` for the immutable one.
    modifier: Option<TokenStream>,
}
59
60fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
61    // Parse the input tokens into a syntax tree.
62    let input = parse_macro_input!(input as DeriveInput);
63    let name = input.ident;
64
65    let VisitType {
66        visit_trait,
67        visitor_trait,
68        modifier,
69    } = visit_type;
70
71    let attributes = Attributes::parse(&input.attrs);
72    // Add a bound `T: Visit` to every type parameter T.
73    let generics = add_trait_bounds(input.generics, visit_type);
74    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
75
76    let (pre_visit, post_visit) = attributes.visit(quote!(self));
77    let children = visit_children(&input.data, visit_type);
78
79    let expanded = quote! {
80        // The generated impl.
81        // Note that it uses [`recursive::recursive`] to protect from stack overflow.
82        // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
83        impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
84             #[cfg_attr(feature = "recursive-protection", recursive::recursive)]
85            fn visit<V: sqlparser::ast::#visitor_trait>(
86                &#modifier self,
87                visitor: &mut V
88            ) -> ::std::ops::ControlFlow<V::Break> {
89                #pre_visit
90                #children
91                #post_visit
92                ::std::ops::ControlFlow::Continue(())
93            }
94        }
95    };
96
97    proc_macro::TokenStream::from(expanded)
98}
99
/// Attributes that can be provided to this macro via `#[visit(...)]`.
///
/// The only form the parser accepts is a single `with` key with a string value:
///
/// `#[visit(with = "visit_expr")]`
///
/// The value names a visitor hook; the generated code calls
/// `visitor.pre_<value>(..)` before and `visitor.post_<value>(..)` after
/// visiting the annotated item.
#[derive(Default)]
struct Attributes {
    /// Content for the `with` attribute
    with: Option<Ident>,
}
108
/// Parsed contents of a single `#[visit(with = "...")]` attribute.
struct WithIdent {
    /// Identifier built from the string literal given to `with`.
    with: Option<Ident>,
}
112impl Parse for WithIdent {
113    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
114        let mut result = WithIdent { with: None };
115        let ident = input.parse::<Ident>()?;
116        if ident != "with" {
117            return Err(syn::Error::new(
118                ident.span(),
119                "Expected identifier to be `with`",
120            ));
121        }
122        input.parse::<Token!(=)>()?;
123        let s = input.parse::<LitStr>()?;
124        result.with = Some(format_ident!("{}", s.value(), span = s.span()));
125        Ok(result)
126    }
127}
128
129impl Attributes {
130    fn parse(attrs: &[Attribute]) -> Self {
131        let mut out = Self::default();
132        for attr in attrs {
133            if let Meta::List(ref metalist) = attr.meta {
134                if metalist.path.is_ident("visit") {
135                    match syn::parse2::<WithIdent>(metalist.tokens.clone()) {
136                        Ok(with_ident) => {
137                            out.with = with_ident.with;
138                        }
139                        Err(e) => {
140                            panic!("{}", e);
141                        }
142                    }
143                }
144            }
145        }
146        out
147    }
148
149    /// Returns the pre and post visit token streams
150    fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
151        let pre_visit = self.with.as_ref().map(|m| {
152            let m = format_ident!("pre_{}", m);
153            quote!(visitor.#m(#s)?;)
154        });
155        let post_visit = self.with.as_ref().map(|m| {
156            let m = format_ident!("post_{}", m);
157            quote!(visitor.#m(#s)?;)
158        });
159        (pre_visit, post_visit)
160    }
161}
162
163// Add a bound `T: Visit` to every type parameter T.
164fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics {
165    for param in &mut generics.params {
166        if let GenericParam::Type(ref mut type_param) = *param {
167            type_param
168                .bounds
169                .push(parse_quote!(sqlparser::ast::#visit_trait));
170        }
171    }
172    generics
173}
174
175// Generate the body of the visit implementation for the given type
176fn visit_children(
177    data: &Data,
178    VisitType {
179        visit_trait,
180        modifier,
181        ..
182    }: &VisitType,
183) -> TokenStream {
184    match data {
185        Data::Struct(data) => match &data.fields {
186            Fields::Named(fields) => {
187                let recurse = fields.named.iter().map(|f| {
188                    let name = &f.ident;
189                    let is_option = is_option(&f.ty);
190                    let attributes = Attributes::parse(&f.attrs);
191                    if is_option && attributes.with.is_some() {
192                        let (pre_visit, post_visit) = attributes.visit(quote!(value));
193                        quote_spanned!(f.span() =>
194                            if let Some(value) = &#modifier self.#name {
195                                #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
196                            }
197                        )
198                    } else {
199                        let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
200                        quote_spanned!(f.span() =>
201                            #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
202                        )
203                    }
204                });
205                quote! {
206                    #(#recurse)*
207                }
208            }
209            Fields::Unnamed(fields) => {
210                let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
211                    let index = Index::from(i);
212                    let attributes = Attributes::parse(&f.attrs);
213                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
214                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
215                });
216                quote! {
217                    #(#recurse)*
218                }
219            }
220            Fields::Unit => {
221                quote!()
222            }
223        },
224        Data::Enum(data) => {
225            let statements = data.variants.iter().map(|v| {
226                let name = &v.ident;
227                match &v.fields {
228                    Fields::Named(fields) => {
229                        let names = fields.named.iter().map(|f| &f.ident);
230                        let visit = fields.named.iter().map(|f| {
231                            let name = &f.ident;
232                            let attributes = Attributes::parse(&f.attrs);
233                            let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
234                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
235                        });
236
237                        quote!(
238                            Self::#name { #(#names),* } => {
239                                #(#visit)*
240                            }
241                        )
242                    }
243                    Fields::Unnamed(fields) => {
244                        let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
245                        let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
246                            let name = format_ident!("_{}", i);
247                            let attributes = Attributes::parse(&f.attrs);
248                            let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
249                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
250                        });
251
252                        quote! {
253                            Self::#name ( #(#names),*) => {
254                                #(#visit)*
255                            }
256                        }
257                    }
258                    Fields::Unit => {
259                        quote! {
260                            Self::#name => {}
261                        }
262                    }
263                }
264            });
265
266            quote! {
267                match self {
268                    #(#statements),*
269                }
270            }
271        }
272        Data::Union(_) => unimplemented!(),
273    }
274}
275
276fn is_option(ty: &Type) -> bool {
277    if let Type::Path(TypePath {
278        path: Path { segments, .. },
279        ..
280    }) = ty
281    {
282        if let Some(segment) = segments.last() {
283            if segment.ident == "Option" {
284                if let PathArguments::AngleBracketed(args) = &segment.arguments {
285                    return args.args.len() == 1;
286                }
287            }
288        }
289    }
290    false
291}