skp_validator_derive/
lib.rs

1use proc_macro::TokenStream;
2use syn::{parse_macro_input, DeriveInput, Data, Fields};
3use quote::quote;
4
5mod parser;
6mod schema_codegen;
7
8use parser::{ValidationRule, parse_validate_attribute};
9use schema_codegen::generate_metadata_impl;
10
11/// Derive macro for implementing the Validate trait.
12#[proc_macro_derive(Validate, attributes(validate))]
13pub fn derive_validate(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = &input.ident;
16    let generics = &input.generics;
17    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
18    
19    // Get fields from struct
20    let fields = match &input.data {
21        Data::Struct(data) => match &data.fields {
22            Fields::Named(fields) => &fields.named,
23            Fields::Unnamed(_) => {
24                return syn::Error::new_spanned(
25                    &input,
26                    "Validate can only be derived for structs with named fields"
27                )
28                .to_compile_error()
29                .into();
30            }
31            Fields::Unit => {
32                return syn::Error::new_spanned(
33                    &input,
34                    "Validate cannot be derived for unit structs"
35                )
36                .to_compile_error()
37                .into();
38            }
39        },
40        Data::Enum(_) => {
41            return syn::Error::new_spanned(
42                &input,
43                "Validate for enums is not yet implemented"
44            )
45            .to_compile_error()
46            .into();
47        }
48        Data::Union(_) => {
49            return syn::Error::new_spanned(
50                &input,
51                "Validate cannot be derived for unions"
52            )
53            .to_compile_error()
54            .into();
55        }
56    };
57    
58    // Generate field validation code
59    let field_validations: Vec<_> = fields.iter().filter_map(|field| {
60        generate_field_validation(field)
61    }).collect();
62    
63    // Generate schema metadata
64    let fields_named = match &input.data {
65        Data::Struct(data) => match &data.fields {
66            Fields::Named(fields) => fields,
67            _ => unreachable!(),
68        },
69        _ => unreachable!(),
70    };
71    let metadata_impl = generate_metadata_impl(name, generics, fields_named);
72    
73    // Generate implementation
74    let expanded = quote! {
75        impl #impl_generics skp_validator_core::Validate for #name #ty_generics #where_clause {
76            fn validate_with_context(
77                &self,
78                ctx: &skp_validator_core::ValidationContext
79            ) -> skp_validator_core::ValidationResult<()> {
80                let mut errors = skp_validator_core::ValidationErrors::new();
81                
82                #(#field_validations)*
83                
84                if errors.is_empty() {
85                    Ok(())
86                } else {
87                    Err(errors)
88                }
89            }
90        }
91        
92        #metadata_impl
93    };
94    
95    TokenStream::from(expanded)
96}
97
98fn generate_field_validation(field: &syn::Field) -> Option<proc_macro2::TokenStream> {
99    let field_name = field.ident.as_ref().unwrap();
100    let field_name_str = field_name.to_string();
101    
102    // Check validation attribute
103    if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("validate")) {
104        let rules = match parse_validate_attribute(attr) {
105            Ok(rules) => rules,
106            Err(e) => {
107                let err_msg = e.to_string();
108                return Some(quote! { compile_error!(#err_msg); });
109            }
110        };
111        
112        let is_option = is_option(&field.ty);
113        let field_type = &field.ty;
114        
115        let validations: Vec<_> = rules.iter().filter_map(|rule| {
116             generate_rule_validation(field_name, &field_name_str, field_type, rule, is_option)
117        }).collect();
118        
119        Some(quote! {
120            #(#validations)*
121        })
122    } else {
123        None
124    }
125}
126
127fn generate_rule_validation(
128    field_name: &syn::Ident,
129    field_name_str: &str,
130    field_type: &syn::Type,
131    rule: &ValidationRule,
132    is_option: bool
133) -> Option<proc_macro2::TokenStream> {
134    match rule {
135        ValidationRule::Skip => None,
136        
137        ValidationRule::Required { message } => {
138            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("field is required".to_string()));
139            Some(quote! {
140                if self.#field_name == <#field_type as Default>::default() {
141                     errors.add_field_error(
142                        #field_name_str, 
143                        skp_validator_core::ValidationError::new(
144                            #field_name_str,
145                            "required", 
146                            #error_message
147                        )
148                     );
149                }
150            })
151        },
152        
153
154        ValidationRule::Nested => {
155            Some(quote! {
156                if let Err(mut nested_errors) = self.#field_name.validate_with_context(ctx) {
157                    errors.add_nested_errors(#field_name_str, nested_errors);
158                }
159            })
160        },
161        
162        ValidationRule::Dive => {
163            Some(quote! {
164                use skp_validator_core::ValidateDive;
165                let path = skp_validator_core::FieldPath::from_field(#field_name_str);
166                if let Err(dive_errors) = self.#field_name.validate_dive(&path, ctx) {
167                    errors.merge(dive_errors);
168                }
169            })
170        },
171
172        _ => {
173            let rule_check = generate_leaf_rule_check(rule, field_name, field_name_str);
174            if let Some(check) = rule_check {
175                if is_option {
176                     Some(quote! {
177                         if let Some(ref val) = self.#field_name {
178                             #check
179                         }
180                     })
181                } else {
182                     Some(quote! {
183                         let val = &self.#field_name;
184                         #check
185                     })
186                }
187            } else {
188                None
189            }
190        }
191    }
192}
193
194fn generate_leaf_rule_check(
195    rule: &ValidationRule,
196    _field_ident: &syn::Ident,
197    field_name_str: &str
198) -> Option<proc_macro2::TokenStream> {
199    match rule {
200        ValidationRule::Length { min, max, equal, message } => {
201             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid length".to_string()));
202             let min = quote_option_usize(min);
203             let max = quote_option_usize(max);
204             let equal = quote_option_usize(equal);
205             Some(quote! {
206                 let len = val.len();
207                 let mut valid = true;
208                 if let Some(m) = #min { if len < m { valid = false; } }
209                 if let Some(m) = #max { if len > m { valid = false; } }
210                 if let Some(e) = #equal { if len != e { valid = false; } }
211                 if !valid {
212                      errors.add_field_error(
213                        #field_name_str, 
214                        skp_validator_core::ValidationError::new(
215                            #field_name_str,
216                            "length", 
217                            #error_message
218                        )
219                      );
220                 }
221             })
222        },
223
224        ValidationRule::Range { min, max, message } => {
225             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("value out of range".to_string()));
226             
227             let min_check = if let Some(m) = min {
228                 quote! { if *val < (#m as _) { valid = false; } }
229             } else {
230                 quote! {}
231             };
232             
233             let max_check = if let Some(m) = max {
234                 quote! { if *val > (#m as _) { valid = false; } }
235             } else {
236                 quote! {}
237             };
238
239             Some(quote! {
240                 let mut valid = true;
241                 #min_check
242                 #max_check
243                 if !valid {
244                      errors.add_field_error(
245                        #field_name_str,
246                        skp_validator_core::ValidationError::new(
247                            #field_name_str,
248                            "range", 
249                            #error_message
250                        )
251                      );
252                 }
253             })
254        },
255
256        ValidationRule::Email { message } => {
257             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid email".to_string()));
258             Some(quote! {
259                 if !val.contains('@') {
260                      errors.add_field_error(
261                        #field_name_str,
262                        skp_validator_core::ValidationError::new(
263                            #field_name_str,
264                            "email", 
265                            #error_message
266                        )
267                      );
268                 }
269             })
270        },
271        
272        ValidationRule::Url { message } => {
273             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid url".to_string()));
274             Some(quote! {
275                 if !val.starts_with("http") {
276                      errors.add_field_error(
277                        #field_name_str,
278                        skp_validator_core::ValidationError::new(
279                            #field_name_str,
280                            "url", 
281                            #error_message
282                        )
283                      );
284                 }
285             })
286        },
287
288        ValidationRule::Ip { version, message } => {
289            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid IP address".to_string()));
290            let version = quote_option_string(version);
291            Some(quote! {
292                let val_str = val.to_string();
293                if val_str.parse::<std::net::IpAddr>().is_err() {
294                     errors.add_field_error(
295                        #field_name_str,
296                        skp_validator_core::ValidationError::new(
297                            #field_name_str,
298                            "ip", 
299                            #error_message
300                        )
301                     );
302                } else if let Some(ver) = #version {
303                     let ip: std::net::IpAddr = val_str.parse().unwrap();
304                     if ver == "v4" && !ip.is_ipv4() {
305                         errors.add_field_error(
306                            #field_name_str,
307                            skp_validator_core::ValidationError::new(
308                                #field_name_str,
309                                "ip", 
310                                "Expected IPv4".to_string()
311                            )
312                         );
313                     } else if ver == "v6" && !ip.is_ipv6() {
314                         errors.add_field_error(
315                            #field_name_str,
316                            skp_validator_core::ValidationError::new(
317                                #field_name_str,
318                                "ip", 
319                                "Expected IPv6".to_string()
320                            )
321                         );
322                     }
323                }
324            })
325        },
326        
327        ValidationRule::Uuid { version, message } => {
328            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid UUID".to_string()));
329            let version = quote_option_usize(version);
330            Some(quote! {
331                use skp_validator_core::Rule;
332                let mut rule = skp_validator::rules::UuidRule::new();
333                if let Some(v) = #version {
334                    rule = rule.version(v as u8);
335                }
336                rule = rule.message(#error_message);
337                
338                if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
339                     for err in e.errors {
340                         errors.add_field_error(#field_name_str, err);
341                     }
342                }
343            })
344        },
345
346        ValidationRule::Phone { message } => {
347            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid phone number".to_string()));
348            Some(quote! {
349                 use skp_validator_core::Rule;
350                 let rule = skp_validator::rules::PhoneRule::new().message(#error_message);
351                 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
352                     for err in e.errors {
353                         errors.add_field_error(#field_name_str, err);
354                     }
355                 }
356            })
357        },
358        
359        ValidationRule::Prefix { value, message } => {
360            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid prefix".to_string()));
361            Some(quote! {
362                if !val.starts_with(#value) {
363                     errors.add_field_error(
364                        #field_name_str,
365                        skp_validator_core::ValidationError::new(
366                            #field_name_str,
367                            "prefix", 
368                            #error_message
369                        )
370                     );
371                }
372            })
373        },
374
375        ValidationRule::Suffix { value, message } => {
376            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid suffix".to_string()));
377            Some(quote! {
378                if !val.ends_with(#value) {
379                     errors.add_field_error(
380                        #field_name_str,
381                        skp_validator_core::ValidationError::new(
382                            #field_name_str,
383                            "suffix", 
384                            #error_message
385                        )
386                     );
387                }
388            })
389        },
390
391        ValidationRule::Contains { value, message } => {
392            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain value".to_string()));
393            Some(quote! {
394                if !val.contains(#value) {
395                     errors.add_field_error(
396                        #field_name_str,
397                        skp_validator_core::ValidationError::new(
398                            #field_name_str,
399                            "contains", 
400                            #error_message
401                        )
402                     );
403                }
404            })
405        },
406        
407        ValidationRule::Trim { message } => {
408            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be trimmed".to_string()));
409            Some(quote! {
410                if val.trim() != val {
411                     errors.add_field_error(
412                        #field_name_str,
413                        skp_validator_core::ValidationError::new(
414                            #field_name_str,
415                            "trim", 
416                            #error_message
417                        )
418                     );
419                }
420            })
421        },
422        
423        ValidationRule::Uppercase { message } => {
424            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be uppercase".to_string()));
425            Some(quote! {
426                if val.chars().any(|c| c.is_lowercase()) {
427                     errors.add_field_error(
428                        #field_name_str,
429                        skp_validator_core::ValidationError::new(
430                            #field_name_str,
431                            "uppercase", 
432                            #error_message
433                        )
434                     );
435                }
436            })
437        },
438
439        ValidationRule::Lowercase { message } => {
440             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be lowercase".to_string()));
441             Some(quote! {
442                 if val.chars().any(|c| c.is_uppercase()) {
443                      errors.add_field_error(
444                         #field_name_str,
445                         skp_validator_core::ValidationError::new(
446                             #field_name_str,
447                             "lowercase", 
448                             #error_message
449                         )
450                      );
451                 }
452             })
453        },
454
455        ValidationRule::MultipleOf { value, message } => {
456            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Not multiple of value".to_string()));
457            Some(quote! {
458                if val % (#value as _) != 0 {
459                     errors.add_field_error(
460                        #field_name_str,
461                        skp_validator_core::ValidationError::new(
462                            #field_name_str,
463                            "multiple_of", 
464                            #error_message
465                        )
466                     );
467                }
468            })
469        },
470        
471        ValidationRule::AllowedValues { values, message } => {
472            let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Value not allowed".to_string()));
473            let value_tokens = values.iter().map(|v| quote!(#v));
474            Some(quote! {
475                let allowed = vec![#(#value_tokens),*];
476                if !allowed.contains(&val.to_string().as_str()) {
477                     errors.add_field_error(
478                        #field_name_str,
479                        skp_validator_core::ValidationError::new(
480                            #field_name_str,
481                            "allowed_values", 
482                            #error_message
483                        )
484                     );
485                }
486            })
487        },
488        
489        ValidationRule::MustMatch { other, message } => {
490             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Field mismatch".to_string()));
491             let other_ident = syn::Ident::new(other, proc_macro2::Span::call_site());
492             Some(quote! {
493                 if val != &self.#other_ident {
494                      errors.add_field_error(
495                        #field_name_str,
496                        skp_validator_core::ValidationError::new(
497                            #field_name_str,
498                            "must_match", 
499                            #error_message
500                        )
501                      );
502                 }
503             })
504        },
505        
506        ValidationRule::CreditCard { message } => {
507             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid credit card number".to_string()));
508             Some(quote! {
509                 let val_str = val.to_string();
510                 let mut sum = 0;
511                 let mut double = false;
512                 let mut valid = true;
513                 for c in val_str.chars().rev() {
514                     if let Some(mut digit) = c.to_digit(10) {
515                         if double {
516                             digit *= 2;
517                             if digit > 9 { digit -= 9; }
518                         }
519                         sum += digit;
520                         double = !double;
521                     } else {
522                         valid = false;
523                         break;
524                     }
525                 }
526                 if !valid || sum % 10 != 0 {
527                      errors.add_field_error(
528                        #field_name_str,
529                        skp_validator_core::ValidationError::new(
530                            #field_name_str,
531                            "credit_card", 
532                            #error_message
533                        )
534                      );
535                 }
536             })
537        },
538        
539        ValidationRule::Pattern { regex, message } => {
540             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid format".to_string()));
541             Some(quote! {
542                 use skp_validator_core::Rule;
543                 let rule = skp_validator::rules::PatternRule::new(#regex).message(#error_message);
544                 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
545                     for err in e.errors {
546                         errors.add_field_error(#field_name_str, err);
547                     }
548                 }
549             })
550        },
551        
552        ValidationRule::Custom { function, message } => {
553            let function_path: syn::Path = syn::parse_str(function).expect("Invalid function path");
554            let message_override = if let Some(msg) = message {
555                quote! { e.message = #msg.to_string(); }
556            } else {
557                quote! {}
558            };
559            Some(quote! {
560                if let Err(mut e) = #function_path(&val) {
561                    #message_override
562                    errors.add_field_error(#field_name_str, e);
563                }
564            })
565        },
566        
567        ValidationRule::Ascii { message } => {
568             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain only ASCII characters".to_string()));
569             Some(quote! {
570                 use skp_validator_core::Rule;
571                 let rule = skp_validator::rules::AsciiRule::new().message(#error_message);
572                 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
573                     for err in e.errors {
574                         errors.add_field_error(#field_name_str, err);
575                     }
576                 }
577             })
578        },
579
580        ValidationRule::Alphanumeric { message } => {
581             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain only alphanumeric characters".to_string()));
582             Some(quote! {
583                 use skp_validator_core::Rule;
584                 let rule = skp_validator::rules::AlphanumericRule::new().message(#error_message);
585                 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
586                     for err in e.errors {
587                         errors.add_field_error(#field_name_str, err);
588                     }
589                 }
590             })
591        },
592
593        ValidationRule::UniqueItems { message } => {
594             let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Items must be unique".to_string()));
595             Some(quote! {
596                 use skp_validator_core::Rule;
597                 let rule = skp_validator::rules::UniqueItemsRule::new().message(#error_message);
598                 if let Err(mut e) = rule.validate(val, ctx) {
599                     for err in e.errors {
600                         errors.add_field_error(#field_name_str, err);
601                     }
602                 }
603             })
604        },
605
606        _ => None
607    }
608}
609
610fn quote_option_usize(opt: &Option<usize>) -> proc_macro2::TokenStream {
611    match opt {
612        Some(v) => quote!(Some(#v as usize)),
613        None => quote!(None::<usize>),
614    }
615}
616
617fn quote_option_string(opt: &Option<String>) -> proc_macro2::TokenStream {
618    match opt {
619        Some(v) => quote!(Some(#v)),
620        None => quote!(None::<String>),
621    }
622}
623
624fn is_option(ty: &syn::Type) -> bool {
625    if let syn::Type::Path(type_path) = ty
626         && let Some(segment) = type_path.path.segments.last()
627         && segment.ident == "Option"
628    {
629        return true;
630    }
631    false
632}