1use proc_macro::TokenStream;
2use syn::{parse_macro_input, DeriveInput, Data, Fields};
3use quote::quote;
4
5mod parser;
6mod schema_codegen;
7
8use parser::{ValidationRule, parse_validate_attribute};
9use schema_codegen::generate_metadata_impl;
10
11#[proc_macro_derive(Validate, attributes(validate))]
13pub fn derive_validate(input: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(input as DeriveInput);
15 let name = &input.ident;
16 let generics = &input.generics;
17 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
18
19 let fields = match &input.data {
21 Data::Struct(data) => match &data.fields {
22 Fields::Named(fields) => &fields.named,
23 Fields::Unnamed(_) => {
24 return syn::Error::new_spanned(
25 &input,
26 "Validate can only be derived for structs with named fields"
27 )
28 .to_compile_error()
29 .into();
30 }
31 Fields::Unit => {
32 return syn::Error::new_spanned(
33 &input,
34 "Validate cannot be derived for unit structs"
35 )
36 .to_compile_error()
37 .into();
38 }
39 },
40 Data::Enum(_) => {
41 return syn::Error::new_spanned(
42 &input,
43 "Validate for enums is not yet implemented"
44 )
45 .to_compile_error()
46 .into();
47 }
48 Data::Union(_) => {
49 return syn::Error::new_spanned(
50 &input,
51 "Validate cannot be derived for unions"
52 )
53 .to_compile_error()
54 .into();
55 }
56 };
57
58 let field_validations: Vec<_> = fields.iter().filter_map(|field| {
60 generate_field_validation(field)
61 }).collect();
62
63 let fields_named = match &input.data {
65 Data::Struct(data) => match &data.fields {
66 Fields::Named(fields) => fields,
67 _ => unreachable!(),
68 },
69 _ => unreachable!(),
70 };
71 let metadata_impl = generate_metadata_impl(name, generics, fields_named);
72
73 let expanded = quote! {
75 impl #impl_generics skp_validator_core::Validate for #name #ty_generics #where_clause {
76 fn validate_with_context(
77 &self,
78 ctx: &skp_validator_core::ValidationContext
79 ) -> skp_validator_core::ValidationResult<()> {
80 let mut errors = skp_validator_core::ValidationErrors::new();
81
82 #(#field_validations)*
83
84 if errors.is_empty() {
85 Ok(())
86 } else {
87 Err(errors)
88 }
89 }
90 }
91
92 #metadata_impl
93 };
94
95 TokenStream::from(expanded)
96}
97
98fn generate_field_validation(field: &syn::Field) -> Option<proc_macro2::TokenStream> {
99 let field_name = field.ident.as_ref().unwrap();
100 let field_name_str = field_name.to_string();
101
102 if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("validate")) {
104 let rules = match parse_validate_attribute(attr) {
105 Ok(rules) => rules,
106 Err(e) => {
107 let err_msg = e.to_string();
108 return Some(quote! { compile_error!(#err_msg); });
109 }
110 };
111
112 let is_option = is_option(&field.ty);
113 let field_type = &field.ty;
114
115 let validations: Vec<_> = rules.iter().filter_map(|rule| {
116 generate_rule_validation(field_name, &field_name_str, field_type, rule, is_option)
117 }).collect();
118
119 Some(quote! {
120 #(#validations)*
121 })
122 } else {
123 None
124 }
125}
126
127fn generate_rule_validation(
128 field_name: &syn::Ident,
129 field_name_str: &str,
130 field_type: &syn::Type,
131 rule: &ValidationRule,
132 is_option: bool
133) -> Option<proc_macro2::TokenStream> {
134 match rule {
135 ValidationRule::Skip => None,
136
137 ValidationRule::Required { message } => {
138 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("field is required".to_string()));
139 Some(quote! {
140 if self.#field_name == <#field_type as Default>::default() {
141 errors.add_field_error(
142 #field_name_str,
143 skp_validator_core::ValidationError::new(
144 #field_name_str,
145 "required",
146 #error_message
147 )
148 );
149 }
150 })
151 },
152
153
154 ValidationRule::Nested => {
155 Some(quote! {
156 if let Err(mut nested_errors) = self.#field_name.validate_with_context(ctx) {
157 errors.add_nested_errors(#field_name_str, nested_errors);
158 }
159 })
160 },
161
162 ValidationRule::Dive => {
163 Some(quote! {
164 use skp_validator_core::ValidateDive;
165 let path = skp_validator_core::FieldPath::from_field(#field_name_str);
166 if let Err(dive_errors) = self.#field_name.validate_dive(&path, ctx) {
167 errors.merge(dive_errors);
168 }
169 })
170 },
171
172 _ => {
173 let rule_check = generate_leaf_rule_check(rule, field_name, field_name_str);
174 if let Some(check) = rule_check {
175 if is_option {
176 Some(quote! {
177 if let Some(ref val) = self.#field_name {
178 #check
179 }
180 })
181 } else {
182 Some(quote! {
183 let val = &self.#field_name;
184 #check
185 })
186 }
187 } else {
188 None
189 }
190 }
191 }
192}
193
194fn generate_leaf_rule_check(
195 rule: &ValidationRule,
196 _field_ident: &syn::Ident,
197 field_name_str: &str
198) -> Option<proc_macro2::TokenStream> {
199 match rule {
200 ValidationRule::Length { min, max, equal, message } => {
201 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid length".to_string()));
202 let min = quote_option_usize(min);
203 let max = quote_option_usize(max);
204 let equal = quote_option_usize(equal);
205 Some(quote! {
206 let len = val.len();
207 let mut valid = true;
208 if let Some(m) = #min { if len < m { valid = false; } }
209 if let Some(m) = #max { if len > m { valid = false; } }
210 if let Some(e) = #equal { if len != e { valid = false; } }
211 if !valid {
212 errors.add_field_error(
213 #field_name_str,
214 skp_validator_core::ValidationError::new(
215 #field_name_str,
216 "length",
217 #error_message
218 )
219 );
220 }
221 })
222 },
223
224 ValidationRule::Range { min, max, message } => {
225 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("value out of range".to_string()));
226
227 let min_check = if let Some(m) = min {
228 quote! { if *val < (#m as _) { valid = false; } }
229 } else {
230 quote! {}
231 };
232
233 let max_check = if let Some(m) = max {
234 quote! { if *val > (#m as _) { valid = false; } }
235 } else {
236 quote! {}
237 };
238
239 Some(quote! {
240 let mut valid = true;
241 #min_check
242 #max_check
243 if !valid {
244 errors.add_field_error(
245 #field_name_str,
246 skp_validator_core::ValidationError::new(
247 #field_name_str,
248 "range",
249 #error_message
250 )
251 );
252 }
253 })
254 },
255
256 ValidationRule::Email { message } => {
257 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid email".to_string()));
258 Some(quote! {
259 if !val.contains('@') {
260 errors.add_field_error(
261 #field_name_str,
262 skp_validator_core::ValidationError::new(
263 #field_name_str,
264 "email",
265 #error_message
266 )
267 );
268 }
269 })
270 },
271
272 ValidationRule::Url { message } => {
273 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("invalid url".to_string()));
274 Some(quote! {
275 if !val.starts_with("http") {
276 errors.add_field_error(
277 #field_name_str,
278 skp_validator_core::ValidationError::new(
279 #field_name_str,
280 "url",
281 #error_message
282 )
283 );
284 }
285 })
286 },
287
288 ValidationRule::Ip { version, message } => {
289 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid IP address".to_string()));
290 let version = quote_option_string(version);
291 Some(quote! {
292 let val_str = val.to_string();
293 if val_str.parse::<std::net::IpAddr>().is_err() {
294 errors.add_field_error(
295 #field_name_str,
296 skp_validator_core::ValidationError::new(
297 #field_name_str,
298 "ip",
299 #error_message
300 )
301 );
302 } else if let Some(ver) = #version {
303 let ip: std::net::IpAddr = val_str.parse().unwrap();
304 if ver == "v4" && !ip.is_ipv4() {
305 errors.add_field_error(
306 #field_name_str,
307 skp_validator_core::ValidationError::new(
308 #field_name_str,
309 "ip",
310 "Expected IPv4".to_string()
311 )
312 );
313 } else if ver == "v6" && !ip.is_ipv6() {
314 errors.add_field_error(
315 #field_name_str,
316 skp_validator_core::ValidationError::new(
317 #field_name_str,
318 "ip",
319 "Expected IPv6".to_string()
320 )
321 );
322 }
323 }
324 })
325 },
326
327 ValidationRule::Uuid { version, message } => {
328 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid UUID".to_string()));
329 let version = quote_option_usize(version);
330 Some(quote! {
331 use skp_validator_core::Rule;
332 let mut rule = skp_validator::rules::UuidRule::new();
333 if let Some(v) = #version {
334 rule = rule.version(v as u8);
335 }
336 rule = rule.message(#error_message);
337
338 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
339 for err in e.errors {
340 errors.add_field_error(#field_name_str, err);
341 }
342 }
343 })
344 },
345
346 ValidationRule::Phone { message } => {
347 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid phone number".to_string()));
348 Some(quote! {
349 use skp_validator_core::Rule;
350 let rule = skp_validator::rules::PhoneRule::new().message(#error_message);
351 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
352 for err in e.errors {
353 errors.add_field_error(#field_name_str, err);
354 }
355 }
356 })
357 },
358
359 ValidationRule::Prefix { value, message } => {
360 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid prefix".to_string()));
361 Some(quote! {
362 if !val.starts_with(#value) {
363 errors.add_field_error(
364 #field_name_str,
365 skp_validator_core::ValidationError::new(
366 #field_name_str,
367 "prefix",
368 #error_message
369 )
370 );
371 }
372 })
373 },
374
375 ValidationRule::Suffix { value, message } => {
376 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid suffix".to_string()));
377 Some(quote! {
378 if !val.ends_with(#value) {
379 errors.add_field_error(
380 #field_name_str,
381 skp_validator_core::ValidationError::new(
382 #field_name_str,
383 "suffix",
384 #error_message
385 )
386 );
387 }
388 })
389 },
390
391 ValidationRule::Contains { value, message } => {
392 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain value".to_string()));
393 Some(quote! {
394 if !val.contains(#value) {
395 errors.add_field_error(
396 #field_name_str,
397 skp_validator_core::ValidationError::new(
398 #field_name_str,
399 "contains",
400 #error_message
401 )
402 );
403 }
404 })
405 },
406
407 ValidationRule::Trim { message } => {
408 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be trimmed".to_string()));
409 Some(quote! {
410 if val.trim() != val {
411 errors.add_field_error(
412 #field_name_str,
413 skp_validator_core::ValidationError::new(
414 #field_name_str,
415 "trim",
416 #error_message
417 )
418 );
419 }
420 })
421 },
422
423 ValidationRule::Uppercase { message } => {
424 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be uppercase".to_string()));
425 Some(quote! {
426 if val.chars().any(|c| c.is_lowercase()) {
427 errors.add_field_error(
428 #field_name_str,
429 skp_validator_core::ValidationError::new(
430 #field_name_str,
431 "uppercase",
432 #error_message
433 )
434 );
435 }
436 })
437 },
438
439 ValidationRule::Lowercase { message } => {
440 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must be lowercase".to_string()));
441 Some(quote! {
442 if val.chars().any(|c| c.is_uppercase()) {
443 errors.add_field_error(
444 #field_name_str,
445 skp_validator_core::ValidationError::new(
446 #field_name_str,
447 "lowercase",
448 #error_message
449 )
450 );
451 }
452 })
453 },
454
455 ValidationRule::MultipleOf { value, message } => {
456 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Not multiple of value".to_string()));
457 Some(quote! {
458 if val % (#value as _) != 0 {
459 errors.add_field_error(
460 #field_name_str,
461 skp_validator_core::ValidationError::new(
462 #field_name_str,
463 "multiple_of",
464 #error_message
465 )
466 );
467 }
468 })
469 },
470
471 ValidationRule::AllowedValues { values, message } => {
472 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Value not allowed".to_string()));
473 let value_tokens = values.iter().map(|v| quote!(#v));
474 Some(quote! {
475 let allowed = vec![#(#value_tokens),*];
476 if !allowed.contains(&val.to_string().as_str()) {
477 errors.add_field_error(
478 #field_name_str,
479 skp_validator_core::ValidationError::new(
480 #field_name_str,
481 "allowed_values",
482 #error_message
483 )
484 );
485 }
486 })
487 },
488
489 ValidationRule::MustMatch { other, message } => {
490 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Field mismatch".to_string()));
491 let other_ident = syn::Ident::new(other, proc_macro2::Span::call_site());
492 Some(quote! {
493 if val != &self.#other_ident {
494 errors.add_field_error(
495 #field_name_str,
496 skp_validator_core::ValidationError::new(
497 #field_name_str,
498 "must_match",
499 #error_message
500 )
501 );
502 }
503 })
504 },
505
506 ValidationRule::CreditCard { message } => {
507 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid credit card number".to_string()));
508 Some(quote! {
509 let val_str = val.to_string();
510 let mut sum = 0;
511 let mut double = false;
512 let mut valid = true;
513 for c in val_str.chars().rev() {
514 if let Some(mut digit) = c.to_digit(10) {
515 if double {
516 digit *= 2;
517 if digit > 9 { digit -= 9; }
518 }
519 sum += digit;
520 double = !double;
521 } else {
522 valid = false;
523 break;
524 }
525 }
526 if !valid || sum % 10 != 0 {
527 errors.add_field_error(
528 #field_name_str,
529 skp_validator_core::ValidationError::new(
530 #field_name_str,
531 "credit_card",
532 #error_message
533 )
534 );
535 }
536 })
537 },
538
539 ValidationRule::Pattern { regex, message } => {
540 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Invalid format".to_string()));
541 Some(quote! {
542 use skp_validator_core::Rule;
543 let rule = skp_validator::rules::PatternRule::new(#regex).message(#error_message);
544 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
545 for err in e.errors {
546 errors.add_field_error(#field_name_str, err);
547 }
548 }
549 })
550 },
551
552 ValidationRule::Custom { function, message } => {
553 let function_path: syn::Path = syn::parse_str(function).expect("Invalid function path");
554 let message_override = if let Some(msg) = message {
555 quote! { e.message = #msg.to_string(); }
556 } else {
557 quote! {}
558 };
559 Some(quote! {
560 if let Err(mut e) = #function_path(&val) {
561 #message_override
562 errors.add_field_error(#field_name_str, e);
563 }
564 })
565 },
566
567 ValidationRule::Ascii { message } => {
568 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain only ASCII characters".to_string()));
569 Some(quote! {
570 use skp_validator_core::Rule;
571 let rule = skp_validator::rules::AsciiRule::new().message(#error_message);
572 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
573 for err in e.errors {
574 errors.add_field_error(#field_name_str, err);
575 }
576 }
577 })
578 },
579
580 ValidationRule::Alphanumeric { message } => {
581 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Must contain only alphanumeric characters".to_string()));
582 Some(quote! {
583 use skp_validator_core::Rule;
584 let rule = skp_validator::rules::AlphanumericRule::new().message(#error_message);
585 if let Err(mut e) = rule.validate(&val.to_string(), ctx) {
586 for err in e.errors {
587 errors.add_field_error(#field_name_str, err);
588 }
589 }
590 })
591 },
592
593 ValidationRule::UniqueItems { message } => {
594 let error_message = message.as_ref().map(|m| quote!(#m.to_string())).unwrap_or_else(|| quote!("Items must be unique".to_string()));
595 Some(quote! {
596 use skp_validator_core::Rule;
597 let rule = skp_validator::rules::UniqueItemsRule::new().message(#error_message);
598 if let Err(mut e) = rule.validate(val, ctx) {
599 for err in e.errors {
600 errors.add_field_error(#field_name_str, err);
601 }
602 }
603 })
604 },
605
606 _ => None
607 }
608}
609
610fn quote_option_usize(opt: &Option<usize>) -> proc_macro2::TokenStream {
611 match opt {
612 Some(v) => quote!(Some(#v as usize)),
613 None => quote!(None::<usize>),
614 }
615}
616
617fn quote_option_string(opt: &Option<String>) -> proc_macro2::TokenStream {
618 match opt {
619 Some(v) => quote!(Some(#v)),
620 None => quote!(None::<String>),
621 }
622}
623
624fn is_option(ty: &syn::Type) -> bool {
625 if let syn::Type::Path(type_path) = ty
626 && let Some(segment) = type_path.path.segments.last()
627 && segment.ident == "Option"
628 {
629 return true;
630 }
631 false
632}