substrait_validator_derive/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Procedural macro crate for `substrait-validator-core`.
4//!
5//! The derive macros defined here are essentially an ugly workaround for the
6//! lack of any protobuf introspection functionality provided by prost.
7//! Basically, they take (the AST of) the code generated by prost and try to
8//! recover the needed protobuf message metadata from there. Things would have
//! been a *LOT* simpler and a *LOT* less brittle if prost simply provided
//! this information via traits of its own, but alas, there doesn't
11//! seem to be a way to do this without forking prost, and introspection
12//! seems to be a non-goal of that project.
13//!
14//! Besides being ugly, this method is rather brittle and imprecise when it
15//! comes to recovering field names, due to the various case conversions
16//! automatically done by protoc and prost. Some known issues are:
17//!
18//!  - The recovered type name for messages defined within messages uses
19//!    incorrect case conventions, as the procedural macros have no way
20//!    of distinguishing packages from message definition scopes in the
21//!    type path.
22//!  - If the .proto source files use unexpected case conventions for
23//!    various things, the resulting case conventions for types, field names,
24//!    oneof variants, and enum variants will be wrong.
//!  - Whenever the .proto source files name a field using something that is
//!    a reserved word in Rust (notably `type`), prost will use a raw
//!    identifier to represent the name. The derive macros strip the `r#`
//!    prefix again via `cook_ident`, so a field named `type` is reported as
//!    `type` rather than `r#type`; this item is listed only for completeness.
30//!
31//! Ultimately, however, these names are only used for diagnostic messages and
//! the like. In the worst case, the above inconsistencies may confuse the
33//! user, but they should not affect the valid/invalid/maybe-valid result of
34//! the validator or cause compile- or runtime errors.
35
36extern crate proc_macro;
37
38use heck::{ToShoutySnakeCase, ToSnakeCase};
39use proc_macro::TokenStream;
40use quote::quote;
41
42/// Converts a Rust identifier string generated via stringify!() to the
43/// original identifier by "cooking" raw identifiers.
44fn cook_ident(ident: &syn::Ident) -> String {
45    let ident = ident.to_string();
46    if let Some((_, keyword)) = ident.split_once('#') {
47        keyword.to_string()
48    } else {
49        ident
50    }
51}
52
53#[doc(hidden)]
54#[proc_macro_derive(ProtoMeta, attributes(proto_meta))]
55pub fn proto_meta(input: TokenStream) -> TokenStream {
56    proto_meta_derive(syn::parse_macro_input!(input))
57}
58
59fn proto_meta_derive(ast: syn::DeriveInput) -> TokenStream {
60    match ast.data {
61        syn::Data::Struct(ref struct_data) => proto_meta_derive_message(&ast, struct_data),
62        syn::Data::Enum(ref enum_data) => match enum_data.variants.iter().next().unwrap().fields {
63            syn::Fields::Unit => {
64                for variant in enum_data.variants.iter() {
65                    if !matches!(variant.fields, syn::Fields::Unit) {
66                        panic!("all variants of a protobuf oneof enum must have a single, unnamed field");
67                    }
68                }
69
70                proto_meta_derive_enum(&ast, enum_data)
71            }
72            syn::Fields::Unnamed(..) => {
73                for variant in enum_data.variants.iter() {
74                    if let syn::Fields::Unnamed(fields) = &variant.fields {
75                        if fields.unnamed.len() != 1 {
76                            panic!("all variants of a protobuf oneof enum must have a single, unnamed field");
77                        }
78                    } else {
79                        panic!("all variants of a protobuf oneof enum must have a single, unnamed field");
80                    }
81                }
82
83                proto_meta_derive_oneof(&ast, enum_data)
84            }
85            _ => panic!("enum with named elements don't map to protobuf constructs"),
86        },
87        syn::Data::Union(_) => panic!("unions don't map to protobuf constructs"),
88    }
89}
90
/// Storage shape of a prost-generated message field. This determines how the
/// generated `parse_unknown` code has to access the field.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldType {
    /// `Option<T>` where `T` is not a `Box`.
    Optional,
    /// `Option<Box<T>>`.
    BoxedOptional,
    /// `Vec<T>` where `T` is not `u8`.
    Repeated,
    /// Anything else, including `Vec<u8>` (presumably a protobuf `bytes`
    /// field rather than a repeated one).
    Primitive,
}
97
98fn is_repeated(typ: &syn::Type) -> FieldType {
99    if let syn::Type::Path(path) = typ {
100        if let Some(last) = path.path.segments.last() {
101            if last.ident == "Option" {
102                if let syn::PathArguments::AngleBracketed(ref args) = last.arguments {
103                    if let syn::GenericArgument::Type(syn::Type::Path(path2)) =
104                        args.args.first().unwrap()
105                    {
106                        if path2.path.segments.last().unwrap().ident == "Box" {
107                            return FieldType::BoxedOptional;
108                        } else {
109                            return FieldType::Optional;
110                        }
111                    }
112                }
113                panic!("Option without type argument?");
114            } else if last.ident == "Vec" {
115                if let syn::PathArguments::AngleBracketed(ref args) = last.arguments {
116                    if let syn::GenericArgument::Type(syn::Type::Path(path2)) =
117                        args.args.first().unwrap()
118                    {
119                        if path2.path.segments.last().unwrap().ident == "u8" {
120                            return FieldType::Primitive;
121                        } else {
122                            return FieldType::Repeated;
123                        }
124                    }
125                }
126                panic!("Vec without type argument?");
127            }
128        }
129    }
130    FieldType::Primitive
131}
132
133fn proto_meta_derive_message(ast: &syn::DeriveInput, data: &syn::DataStruct) -> TokenStream {
134    let name = &ast.ident;
135    let name_str = cook_ident(name);
136    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
137
138    let parse_unknown_matches: Vec<_> = data
139        .fields
140        .iter()
141        .map(|field| {
142            if let Some(ident) = &field.ident {
143                let ident_str = cook_ident(ident);
144                let action = match is_repeated(&field.ty) {
145                    FieldType::Optional => quote! {
146                        crate::parse::traversal::push_proto_field(
147                            y,
148                            &self.#ident.as_ref(),
149                            #ident_str,
150                            true,
151                            |_, _| Ok(()),
152                        );
153                    },
154                    FieldType::BoxedOptional => quote! {
155                        crate::parse::traversal::push_proto_field(
156                            y,
157                            &self.#ident,
158                            #ident_str,
159                            true,
160                            |_, _| Ok(()),
161                        );
162                    },
163                    FieldType::Repeated => quote! {
164                        crate::parse::traversal::push_proto_repeated_field(
165                            y,
166                            &self.#ident.as_ref(),
167                            #ident_str,
168                            true,
169                            |_, _| Ok(()),
170                            |_, _, _, _, _| (),
171                        );
172                    },
173                    FieldType::Primitive => quote! {
174                        use crate::input::traits::ProtoPrimitive;
175                        if !y.config.ignore_unknown_fields || !self.#ident.proto_primitive_is_default() {
176                            crate::parse::traversal::push_proto_field(
177                                y,
178                                &Some(&self.#ident),
179                                #ident_str,
180                                true,
181                                |_, _| Ok(()),
182                            );
183                        }
184                    },
185                };
186                quote! {
187                    if !y.field_parsed(#ident_str) {
188                        unknowns = true;
189                        #action
190                    }
191                }
192            } else {
193                panic!("protobuf message fields must have names");
194            }
195        })
196        .collect();
197
198    quote!(
199        impl #impl_generics crate::input::traits::ProtoMessage for #name #ty_generics #where_clause {
200            fn proto_message_type() -> &'static str {
201                use ::once_cell::sync::Lazy;
202                static TYPE_NAME: Lazy<::std::string::String> = Lazy::new(|| {
203                    crate::input::proto::cook_path(module_path!(), #name_str)
204                });
205                &TYPE_NAME
206            }
207        }
208
209        impl #impl_generics crate::input::traits::InputNode for #name #ty_generics #where_clause {
210            fn type_to_node() -> crate::output::tree::Node {
211                use crate::input::traits::ProtoMessage;
212                crate::output::tree::NodeType::ProtoMessage(Self::proto_message_type()).into()
213            }
214
215            fn data_to_node(&self) -> crate::output::tree::Node {
216                use crate::input::traits::ProtoMessage;
217                crate::output::tree::NodeType::ProtoMessage(Self::proto_message_type()).into()
218            }
219
220            fn oneof_variant(&self) -> Option<&'static str> {
221                None
222            }
223
224            fn parse_unknown(
225                &self,
226                y: &mut crate::parse::context::Context<'_>,
227            ) -> bool {
228                let mut unknowns = false;
229                #(#parse_unknown_matches)*
230                unknowns
231            }
232        }
233    )
234    .into()
235}
236
237fn proto_meta_derive_oneof(ast: &syn::DeriveInput, data: &syn::DataEnum) -> TokenStream {
238    let name = &ast.ident;
239    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
240
241    let variant_matches: Vec<_> = data
242        .variants
243        .iter()
244        .map(|variant| {
245            let ident = &variant.ident;
246            let proto_name = cook_ident(ident).to_snake_case();
247            quote! { #name::#ident (_) => #proto_name }
248        })
249        .collect();
250
251    let node_matches: Vec<_> = data
252        .variants
253        .iter()
254        .map(|variant| {
255            let ident = &variant.ident;
256            quote! { #name::#ident (x) => x.data_to_node() }
257        })
258        .collect();
259
260    let parse_unknown_matches: Vec<_> = data
261        .variants
262        .iter()
263        .map(|variant| {
264            let ident = &variant.ident;
265            quote! { #name::#ident (x) => x.parse_unknown(y) }
266        })
267        .collect();
268
269    quote!(
270        impl #impl_generics crate::input::traits::ProtoOneOf for #name #ty_generics #where_clause {
271            fn proto_oneof_variant(&self) -> &'static str {
272                match self {
273                    #(#variant_matches),*
274                }
275            }
276        }
277
278        impl #impl_generics crate::input::traits::InputNode for #name #ty_generics #where_clause {
279            fn type_to_node() -> crate::output::tree::Node {
280                crate::output::tree::NodeType::ProtoMissingOneOf.into()
281            }
282
283            fn data_to_node(&self) -> crate::output::tree::Node {
284                match self {
285                    #(#node_matches),*
286                }
287            }
288
289            fn oneof_variant(&self) -> Option<&'static str> {
290                use crate::input::traits::ProtoOneOf;
291                Some(self.proto_oneof_variant())
292            }
293
294            fn parse_unknown(
295                &self,
296                y: &mut crate::parse::context::Context<'_>,
297            ) -> bool {
298                match self {
299                    #(#parse_unknown_matches),*
300                }
301            }
302        }
303    )
304    .into()
305}
306
307fn proto_meta_derive_enum(ast: &syn::DeriveInput, data: &syn::DataEnum) -> TokenStream {
308    let name = &ast.ident;
309    let name_str = cook_ident(name);
310    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
311
312    let upper_name = name_str.to_shouty_snake_case();
313
314    let variant_names: Vec<_> = data
315        .variants
316        .iter()
317        .map(|variant| {
318            let ident = &variant.ident;
319            let proto_name = format!(
320                "{}_{}",
321                upper_name,
322                cook_ident(ident).to_shouty_snake_case()
323            );
324            (ident, proto_name)
325        })
326        .collect();
327
328    let variant_matches: Vec<_> = variant_names
329        .iter()
330        .map(|(ident, proto_name)| {
331            quote! { #name::#ident => #proto_name }
332        })
333        .collect();
334
335    let (_, first_variant_name) = &variant_names[0];
336
337    quote!(
338        impl #impl_generics crate::input::traits::ProtoEnum for #name #ty_generics #where_clause {
339            fn proto_enum_type() -> &'static str {
340                use ::once_cell::sync::Lazy;
341                static TYPE_NAME: Lazy<::std::string::String> = Lazy::new(|| {
342                    crate::input::proto::cook_path(module_path!(), #name_str)
343                });
344                &TYPE_NAME
345            }
346
347            fn proto_enum_default_variant() -> &'static str {
348                #first_variant_name
349            }
350
351            fn proto_enum_variant(&self) -> &'static str {
352                match self {
353                    #(#variant_matches),*
354                }
355            }
356
357            fn proto_enum_from_i32(x: i32) -> Option<Self> {
358                Self::from_i32(x)
359            }
360        }
361    )
362    .into()
363}