Skip to main content

vld_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Expr, Fields, Lit, Meta};
4
5/// Derive macro that generates `vld_parse()`, `parse_value()`, `validate_fields()`,
6/// and `parse_lenient()` methods for a struct, plus implements the `VldParse` trait.
7///
8/// # Usage
9///
10/// ```ignore
11/// use vld::Validate;
12///
13/// #[derive(Debug, Validate)]
14/// struct User {
15///     #[vld(vld::string().min(2).max(50))]
16///     name: String,
17///     #[vld(vld::string().email())]
18///     email: String,
19///     #[vld(vld::number().int().gte(18).optional())]
20///     age: Option<i64>,
21/// }
22///
23/// let user = User::vld_parse(r#"{"name": "Alex", "email": "a@b.com"}"#).unwrap();
24/// ```
25///
26/// # Serde rename support
27///
28/// The derive macro respects `#[serde(rename = "...")]` on fields and
29/// `#[serde(rename_all = "...")]` on the struct:
30///
31/// ```ignore
32/// #[derive(Debug, serde::Serialize, Validate)]
33/// #[serde(rename_all = "camelCase")]
34/// struct ApiRequest {
35///     #[vld(vld::string().min(2))]
36///     first_name: String,
37///     #[vld(vld::string().email())]
38///     email_address: String,
39/// }
40/// // Parses from {"firstName": "...", "emailAddress": "..."}
41/// ```
42///
43/// Supported rename_all conventions: `camelCase`, `PascalCase`, `snake_case`,
44/// `SCREAMING_SNAKE_CASE`, `kebab-case`, `SCREAMING-KEBAB-CASE`.
45///
46/// The expression inside `#[vld(...)]` is used as-is in the generated code.
47/// Make sure the types are in scope (e.g., use `vld::string()` or import via prelude).
48#[proc_macro_derive(Validate, attributes(vld))]
49pub fn derive_validate(input: TokenStream) -> TokenStream {
50    let input = parse_macro_input!(input as DeriveInput);
51    let name = &input.ident;
52
53    // Check for #[serde(rename_all = "...")]
54    let rename_all = get_serde_rename_all(&input.attrs);
55
56    let fields = match &input.data {
57        Data::Struct(data) => match &data.fields {
58            Fields::Named(fields) => &fields.named,
59            _ => panic!("Validate can only be derived for structs with named fields"),
60        },
61        _ => panic!("Validate can only be derived for structs"),
62    };
63
64    let mut field_names = Vec::new();
65    let mut field_types = Vec::new();
66    let mut field_schemas = Vec::new();
67    let mut field_json_keys = Vec::new();
68
69    for field in fields {
70        let fname = field.ident.as_ref().unwrap();
71        let ftype = &field.ty;
72        field_names.push(fname.clone());
73        field_types.push(ftype.clone());
74
75        // Determine JSON key: #[serde(rename = "...")] > rename_all > field name
76        let json_key = get_serde_rename(&field.attrs).unwrap_or_else(|| {
77            if let Some(ref convention) = rename_all {
78                rename_field(&fname.to_string(), convention)
79            } else {
80                fname.to_string()
81            }
82        });
83        field_json_keys.push(json_key);
84
85        let schema_tokens = field
86            .attrs
87            .iter()
88            .find(|attr| attr.path().is_ident("vld"))
89            .map(|attr| attr.parse_args::<proc_macro2::TokenStream>().unwrap())
90            .unwrap_or_else(|| panic!("Field `{}` is missing #[vld(...)] attribute", fname));
91
92        field_schemas.push(schema_tokens);
93    }
94
95    let expanded = quote! {
96        impl #name {
97            /// Parse and validate input data into this struct.
98            ///
99            /// Named `vld_parse` to avoid conflicts with other derive macros
100            /// (e.g. `clap::Parser::parse()`).
101            pub fn vld_parse<__VldInputT: ::vld::input::VldInput + ?Sized>(
102                input: &__VldInputT,
103            ) -> ::std::result::Result<#name, ::vld::error::VldError> {
104                let __vld_json = <__VldInputT as ::vld::input::VldInput>::to_json_value(input)?;
105                Self::parse_value(&__vld_json)
106            }
107
108            /// Parse and validate directly from a `serde_json::Value`.
109            pub fn parse_value(
110                __vld_json: &::vld::serde_json::Value,
111            ) -> ::std::result::Result<#name, ::vld::error::VldError> {
112                use ::vld::schema::VldSchema as _;
113
114                let __vld_obj = __vld_json.as_object().ok_or_else(|| {
115                    ::vld::error::VldError::single(
116                        ::vld::error::IssueCode::InvalidType {
117                            expected: ::std::string::String::from("object"),
118                            received: ::vld::error::value_type_name(__vld_json),
119                        },
120                        ::std::format!(
121                            "Expected object, received {}",
122                            ::vld::error::value_type_name(__vld_json)
123                        ),
124                    )
125                })?;
126
127                let mut __vld_errors = ::vld::error::VldError::new();
128
129                #(
130                    #[allow(non_snake_case)]
131                    let #field_names: ::std::option::Option<#field_types> = {
132                        let __vld_field_schema = { #field_schemas };
133                        let __vld_field_value = __vld_obj
134                            .get(#field_json_keys)
135                            .unwrap_or(&::vld::serde_json::Value::Null);
136                        match __vld_field_schema.parse_value(__vld_field_value) {
137                            ::std::result::Result::Ok(v) => ::std::option::Option::Some(v),
138                            ::std::result::Result::Err(e) => {
139                                __vld_errors = ::vld::error::VldError::merge(
140                                    __vld_errors,
141                                    ::vld::error::VldError::with_prefix(
142                                        e,
143                                        ::vld::error::PathSegment::Field(
144                                            ::std::string::String::from(#field_json_keys),
145                                        ),
146                                    ),
147                                );
148                                ::std::option::Option::None
149                            }
150                        }
151                    };
152                )*
153
154                if !::vld::error::VldError::is_empty(&__vld_errors) {
155                    return ::std::result::Result::Err(__vld_errors);
156                }
157
158                ::std::result::Result::Ok(#name {
159                    #( #field_names: #field_names.unwrap(), )*
160                })
161            }
162
163            /// Validate each field individually and return per-field results.
164            pub fn validate_fields<__VldInputT: ::vld::input::VldInput + ?Sized>(
165                input: &__VldInputT,
166            ) -> ::std::result::Result<
167                ::std::vec::Vec<::vld::error::FieldResult>,
168                ::vld::error::VldError,
169            > {
170                let __vld_json = <__VldInputT as ::vld::input::VldInput>::to_json_value(input)?;
171                Self::validate_fields_value(&__vld_json)
172            }
173
174            /// Validate each field individually from a `serde_json::Value`.
175            pub fn validate_fields_value(
176                __vld_json: &::vld::serde_json::Value,
177            ) -> ::std::result::Result<
178                ::std::vec::Vec<::vld::error::FieldResult>,
179                ::vld::error::VldError,
180            > {
181                let __vld_obj = __vld_json.as_object().ok_or_else(|| {
182                    ::vld::error::VldError::single(
183                        ::vld::error::IssueCode::InvalidType {
184                            expected: ::std::string::String::from("object"),
185                            received: ::vld::error::value_type_name(__vld_json),
186                        },
187                        ::std::format!(
188                            "Expected object, received {}",
189                            ::vld::error::value_type_name(__vld_json)
190                        ),
191                    )
192                })?;
193
194                let mut __vld_results: ::std::vec::Vec<::vld::error::FieldResult> =
195                    ::std::vec::Vec::new();
196
197                #(
198                    {
199                        let __vld_field_schema = { #field_schemas };
200                        let __vld_field_value = __vld_obj
201                            .get(#field_json_keys)
202                            .unwrap_or(&::vld::serde_json::Value::Null);
203
204                        let __vld_result = ::vld::object::DynSchema::dyn_parse(
205                            &__vld_field_schema,
206                            __vld_field_value,
207                        );
208
209                        __vld_results.push(::vld::error::FieldResult {
210                            name: ::std::string::String::from(#field_json_keys),
211                            input: __vld_field_value.clone(),
212                            result: __vld_result,
213                        });
214                    }
215                )*
216
217                ::std::result::Result::Ok(__vld_results)
218            }
219
220            /// Parse leniently: build the struct even when some fields fail.
221            pub fn parse_lenient<__VldInputT: ::vld::input::VldInput + ?Sized>(
222                input: &__VldInputT,
223            ) -> ::std::result::Result<
224                ::vld::error::ParseResult<#name>,
225                ::vld::error::VldError,
226            > {
227                let __vld_json = <__VldInputT as ::vld::input::VldInput>::to_json_value(input)?;
228                Self::parse_lenient_value(&__vld_json)
229            }
230
231            /// Parse leniently from a `serde_json::Value`.
232            pub fn parse_lenient_value(
233                __vld_json: &::vld::serde_json::Value,
234            ) -> ::std::result::Result<
235                ::vld::error::ParseResult<#name>,
236                ::vld::error::VldError,
237            > {
238                use ::vld::schema::VldSchema as _;
239
240                let __vld_obj = __vld_json.as_object().ok_or_else(|| {
241                    ::vld::error::VldError::single(
242                        ::vld::error::IssueCode::InvalidType {
243                            expected: ::std::string::String::from("object"),
244                            received: ::vld::error::value_type_name(__vld_json),
245                        },
246                        ::std::format!(
247                            "Expected object, received {}",
248                            ::vld::error::value_type_name(__vld_json)
249                        ),
250                    )
251                })?;
252
253                let mut __vld_results: ::std::vec::Vec<::vld::error::FieldResult> =
254                    ::std::vec::Vec::new();
255
256                #(
257                    #[allow(non_snake_case)]
258                    let #field_names: #field_types = {
259                        let __vld_field_schema = { #field_schemas };
260                        let __vld_field_value = __vld_obj
261                            .get(#field_json_keys)
262                            .unwrap_or(&::vld::serde_json::Value::Null);
263
264                        match __vld_field_schema.parse_value(__vld_field_value) {
265                            ::std::result::Result::Ok(v) => {
266                                let __json_repr = ::vld::serde_json::to_value(&v)
267                                    .unwrap_or_else(|_| __vld_field_value.clone());
268                                __vld_results.push(::vld::error::FieldResult {
269                                    name: ::std::string::String::from(#field_json_keys),
270                                    input: __vld_field_value.clone(),
271                                    result: ::std::result::Result::Ok(__json_repr),
272                                });
273                                v
274                            }
275                            ::std::result::Result::Err(e) => {
276                                __vld_results.push(::vld::error::FieldResult {
277                                    name: ::std::string::String::from(#field_json_keys),
278                                    input: __vld_field_value.clone(),
279                                    result: ::std::result::Result::Err(e),
280                                });
281                                <#field_types as ::std::default::Default>::default()
282                            }
283                        }
284                    };
285                )*
286
287                let __vld_struct = #name {
288                    #( #field_names, )*
289                };
290
291                ::std::result::Result::Ok(
292                    ::vld::error::ParseResult::new(__vld_struct, __vld_results)
293                )
294            }
295        }
296
297        impl ::vld::schema::VldParse for #name {
298            fn vld_parse_value(
299                value: &::vld::serde_json::Value,
300            ) -> ::std::result::Result<Self, ::vld::error::VldError> {
301                Self::parse_value(value)
302            }
303        }
304    };
305
306    TokenStream::from(expanded)
307}
308
309// ---------------------------------------------------------------------------
310// Serde attribute parsing helpers
311// ---------------------------------------------------------------------------
312
313/// Extract `#[serde(rename_all = "...")]` from struct-level attributes.
314fn get_serde_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
315    for attr in attrs {
316        if !attr.path().is_ident("serde") {
317            continue;
318        }
319        if let Ok(nested) = attr
320            .parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)
321        {
322            for meta in &nested {
323                if let Meta::NameValue(nv) = meta {
324                    if nv.path.is_ident("rename_all") {
325                        if let Expr::Lit(lit) = &nv.value {
326                            if let Lit::Str(s) = &lit.lit {
327                                return Some(s.value());
328                            }
329                        }
330                    }
331                }
332            }
333        }
334    }
335    None
336}
337
338/// Extract `#[serde(rename = "...")]` from field-level attributes.
339fn get_serde_rename(attrs: &[syn::Attribute]) -> Option<String> {
340    for attr in attrs {
341        if !attr.path().is_ident("serde") {
342            continue;
343        }
344        if let Ok(nested) = attr
345            .parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)
346        {
347            for meta in &nested {
348                if let Meta::NameValue(nv) = meta {
349                    if nv.path.is_ident("rename") {
350                        if let Expr::Lit(lit) = &nv.value {
351                            if let Lit::Str(s) = &lit.lit {
352                                return Some(s.value());
353                            }
354                        }
355                    }
356                }
357            }
358        }
359    }
360    None
361}
362
363/// Convert a snake_case field name to the given naming convention.
364fn rename_field(name: &str, convention: &str) -> String {
365    match convention {
366        "camelCase" => to_camel_case(name),
367        "PascalCase" => to_pascal_case(name),
368        "snake_case" => name.to_string(),
369        "SCREAMING_SNAKE_CASE" => name.to_uppercase(),
370        "kebab-case" => name.replace('_', "-"),
371        "SCREAMING-KEBAB-CASE" => name.replace('_', "-").to_uppercase(),
372        _ => name.to_string(),
373    }
374}
375
/// Convert a snake_case name to camelCase.
///
/// Underscores are removed and the character following one is uppercased.
/// The first emitted character is always lowercased, so leading underscores
/// do not capitalize the first word (`_id` becomes `id`, `_foo_bar` becomes
/// `fooBar`). This is the key serde's `camelCase` rule yields for such fields.
fn to_camel_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = false;
    for ch in s.chars() {
        if ch == '_' {
            // A separator only starts a new word once something was emitted;
            // leading underscores are simply dropped.
            capitalize_next = !result.is_empty();
        } else if capitalize_next {
            result.extend(ch.to_uppercase());
            capitalize_next = false;
        } else if result.is_empty() {
            // camelCase always begins with a lowercase letter.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
391
/// Convert a snake_case name to PascalCase.
///
/// The name is cut at every underscore; each non-empty piece gets its first
/// character uppercased (using full Unicode uppercasing) and the pieces are
/// concatenated with the underscores dropped.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::new();
    for word in s.split('_') {
        let mut rest = word.chars();
        // Empty pieces come from leading, trailing or doubled underscores.
        if let Some(initial) = rest.next() {
            out.extend(initial.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
406}