Skip to main content

zod_rs_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5#[proc_macro_derive(ZodSchema, attributes(zod))]
6pub fn derive_zod_schema(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9
10    match &input.data {
11        Data::Struct(data_struct) => match &data_struct.fields {
12            Fields::Named(fields) => {
13                let field_validations = fields.named.iter().map(|field| {
14                    let field_name = &field.ident;
15                    let field_name_str = field_name.as_ref().unwrap().to_string();
16                    let field_type = &field.ty;
17                    let field_attrs = &field.attrs;
18
19                    generate_field_validation_with_attrs(&field_name_str, field_type, field_attrs)
20                });
21
22                let expanded = quote! {
23                    impl #name {
24                        pub fn schema() -> impl zod_rs::Schema<serde_json::Value> {
25                            zod_rs::object()
26                                #(#field_validations)*
27                        }
28
29                        pub fn validate_and_parse(value: &serde_json::Value) -> Result<Self, ::zod_rs::__private::ValidationResult> {
30                            match Self::schema().validate(value) {
31                                Ok(_) => {
32                                    serde_json::from_value(value.clone())
33                                        .map_err(|e| ::zod_rs::__private::ValidationError::custom(format!("Deserialization failed: {}", e)).into())
34                                }
35                                Err(validation_result) => Err(validation_result)
36                            }
37                        }
38
39                        pub fn from_json(json_str: &str) -> Result<Self, ::zod_rs::__private::ParseError> {
40                            let value: serde_json::Value = serde_json::from_str(json_str)?;
41                            Ok(Self::validate_and_parse(&value)?)
42                        }
43
44                        pub fn validate_json(json_str: &str) -> Result<serde_json::Value, ::zod_rs::__private::ParseError> {
45                            let value: serde_json::Value = serde_json::from_str(json_str)?;
46                            Self::schema().validate(&value)?;
47                            Ok(value)
48                        }
49                    }
50                };
51
52                TokenStream::from(expanded)
53            }
54            Fields::Unnamed(_) => {
55                let error = syn::Error::new_spanned(
56                    &input,
57                    "ZodSchema can only be derived for structs with named fields, not tuple structs",
58                );
59                TokenStream::from(error.to_compile_error())
60            }
61            Fields::Unit => {
62                let error = syn::Error::new_spanned(
63                    &input,
64                    "ZodSchema can only be derived for structs with named fields, not unit structs",
65                );
66                TokenStream::from(error.to_compile_error())
67            }
68        },
69        Data::Enum(data_enum) => generate_enum_schema(name, data_enum),
70        Data::Union(_) => {
71            let error = syn::Error::new_spanned(
72                &input,
73                "ZodSchema cannot be derived for unions",
74            );
75            TokenStream::from(error.to_compile_error())
76        }
77    }
78}
79
/// Validation rules gathered from the `#[zod(...)]` attributes of one field.
///
/// Every rule starts out unset (`None` / `false`); `parse_zod_attributes` fills in
/// the ones that were declared.
#[derive(Default)]
struct ZodAttributes {
    // --- numeric bounds: emitted as `.min(..)` / `.max(..)` on number schemas ---
    min: Option<f64>,
    max: Option<f64>,

    // --- size rules: emitted as `.length(..)`, `.min(..)`, `.max(..)` on string
    // --- and array schemas
    length: Option<usize>,
    min_length: Option<usize>,
    max_length: Option<usize>,

    // --- string content rules ---
    starts_with: Option<String>,
    ends_with: Option<String>,
    includes: Option<String>,
    regex: Option<String>,
    email: bool,
    url: bool,

    // --- numeric flags ---
    positive: bool,
    negative: bool,
    nonnegative: bool,
    nonpositive: bool,
    int: bool,
    finite: bool,
}
100
101fn parse_zod_attributes(attrs: &[Attribute]) -> ZodAttributes {
102    let mut zod_attrs = ZodAttributes::default();
103
104    for attr in attrs {
105        if attr.path().is_ident("zod") {
106            if let Meta::List(meta_list) = &attr.meta {
107                let tokens: Vec<_> = meta_list.tokens.clone().into_iter().collect();
108                let mut i = 0;
109
110                while i < tokens.len() {
111                    let token_str = tokens[i].to_string();
112
113                    match token_str.as_str() {
114                        "min_length" => {
115                            if i + 1 < tokens.len() {
116                                let value_token = tokens[i + 1].to_string();
117                                if let Some(value) = extract_number_from_parens(&value_token) {
118                                    zod_attrs.min_length = Some(value);
119                                }
120                                i += 1; // Skip the value token
121                            }
122                        }
123                        "max_length" => {
124                            if i + 1 < tokens.len() {
125                                let value_token = tokens[i + 1].to_string();
126                                if let Some(value) = extract_number_from_parens(&value_token) {
127                                    zod_attrs.max_length = Some(value);
128                                }
129                                i += 1;
130                            }
131                        }
132                        "length" => {
133                            if i + 1 < tokens.len() {
134                                let value_token = tokens[i + 1].to_string();
135                                if let Some(value) = extract_number_from_parens(&value_token) {
136                                    zod_attrs.length = Some(value);
137                                }
138                                i += 1;
139                            }
140                        }
141                        "min" => {
142                            if i + 1 < tokens.len() {
143                                let value_token = tokens[i + 1].to_string();
144                                if let Some(value_str) = extract_string_from_parens(&value_token) {
145                                    if let Ok(value) = value_str.parse::<f64>() {
146                                        zod_attrs.min = Some(value);
147                                    }
148                                }
149                                i += 1;
150                            }
151                        }
152                        "max" => {
153                            if i + 1 < tokens.len() {
154                                let value_token = tokens[i + 1].to_string();
155                                if let Some(value_str) = extract_string_from_parens(&value_token) {
156                                    if let Ok(value) = value_str.parse::<f64>() {
157                                        zod_attrs.max = Some(value);
158                                    }
159                                }
160                                i += 1;
161                            }
162                        }
163                        "starts_with" => {
164                            if i + 1 < tokens.len() {
165                                let value_token = tokens[i + 1].to_string();
166                                if let Some(value) = extract_string_from_parens(&value_token) {
167                                    zod_attrs.starts_with = Some(strip_quotes(&value));
168                                }
169                                i += 1;
170                            }
171                        }
172                        "ends_with" => {
173                            if i + 1 < tokens.len() {
174                                let value_token = tokens[i + 1].to_string();
175                                if let Some(value) = extract_string_from_parens(&value_token) {
176                                    zod_attrs.ends_with = Some(strip_quotes(&value));
177                                }
178                                i += 1;
179                            }
180                        }
181                        "includes" => {
182                            if i + 1 < tokens.len() {
183                                let value_token = tokens[i + 1].to_string();
184                                if let Some(value) = extract_string_from_parens(&value_token) {
185                                    zod_attrs.includes = Some(strip_quotes(&value));
186                                }
187                                i += 1;
188                            }
189                        }
190                        "regex" => {
191                            if i + 1 < tokens.len() {
192                                let value_token = tokens[i + 1].to_string();
193                                if let Some(value) = extract_string_from_parens(&value_token) {
194                                    zod_attrs.regex = Some(strip_quotes(&value));
195                                }
196                                i += 1;
197                            }
198                        }
199                        "email" => {
200                            zod_attrs.email = true;
201                        }
202                        "url" => {
203                            zod_attrs.url = true;
204                        }
205                        "positive" => {
206                            zod_attrs.positive = true;
207                        }
208                        "negative" => {
209                            zod_attrs.negative = true;
210                        }
211                        "nonnegative" => {
212                            zod_attrs.nonnegative = true;
213                        }
214                        "nonpositive" => {
215                            zod_attrs.nonpositive = true;
216                        }
217                        "int" => {
218                            zod_attrs.int = true;
219                        }
220                        "finite" => {
221                            zod_attrs.finite = true;
222                        }
223                        "," => {
224                            // Skip commas
225                        }
226                        _ => {
227                            // Skip unknown tokens
228                        }
229                    }
230
231                    i += 1;
232                }
233            }
234        }
235    }
236
237    zod_attrs
238}
239
/// Parses the unsigned integer inside a token of the form `(<digits>)`.
///
/// Returns `None` when the surrounding parentheses are missing or the text
/// between them does not parse as `usize`.
fn extract_number_from_parens(token: &str) -> Option<usize> {
    let after_open = token.strip_prefix('(')?;
    let digits = after_open.strip_suffix(')')?;
    digits.parse().ok()
}
246
/// Returns the text between the outermost parentheses of `token`.
///
/// Yields `None` unless `token` starts with `(` and ends with `)`; the inner
/// text is returned unchanged (quotes included).
fn extract_string_from_parens(token: &str) -> Option<String> {
    let inner = token.strip_prefix('(')?.strip_suffix(')')?;
    Some(inner.to_owned())
}
253
/// Removes the surrounding quotes from the source text of a string literal.
///
/// Handles regular literals (`"..."`) and raw literals with any number of `#`
/// fences (`r"..."`, `r#"..."#`, ...). Anything else, including a raw literal
/// whose closing fence does not match, is returned unchanged.
///
/// Only the delimiters are removed: escape sequences inside a regular literal
/// are not decoded.
fn strip_quotes(value: &str) -> String {
    // Regular string literal.
    if let Some(inner) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        return inner.to_string();
    }

    // Raw string literal: `r`, N hashes, `"`, body, `"`, N hashes.
    if let Some(rest) = value.strip_prefix('r') {
        let hash_count = rest.len() - rest.trim_start_matches('#').len();
        // '#' is ASCII, so `hash_count` is always a valid char boundary.
        let (fence, quoted) = rest.split_at(hash_count);
        let inner = quoted
            .strip_prefix('"')
            .and_then(|s| s.strip_suffix(fence))
            .and_then(|s| s.strip_suffix('"'));
        if let Some(inner) = inner {
            return inner.to_string();
        }
    }

    // Not a quoted literal: keep as-is.
    value.to_string()
}
267
268fn generate_field_validation_with_attrs(
269    field_name: &str,
270    field_type: &syn::Type,
271    attrs: &[Attribute],
272) -> proc_macro2::TokenStream {
273    let zod_attrs = parse_zod_attributes(attrs);
274    let is_optional = is_option_type(field_type);
275
276    if is_optional {
277        let inner_type = get_option_inner_type(field_type);
278        let base_validation = generate_base_validation_with_attrs(&inner_type, &zod_attrs);
279        quote! { .optional_field(#field_name, #base_validation) }
280    } else {
281        let base_validation = generate_base_validation_with_attrs(field_type, &zod_attrs);
282        quote! { .field(#field_name, #base_validation) }
283    }
284}
285
286fn generate_base_validation_with_attrs(
287    field_type: &syn::Type,
288    zod_attrs: &ZodAttributes,
289) -> proc_macro2::TokenStream {
290    if let syn::Type::Path(type_path) = field_type {
291        if let Some(segment) = type_path.path.segments.last() {
292            let type_name = segment.ident.to_string();
293
294            match type_name.as_str() {
295                "String" => {
296                    let mut validation = quote! { zod_rs::string() };
297
298                    if let Some(min) = zod_attrs.min_length {
299                        validation = quote! { #validation.min(#min) };
300                    }
301                    if let Some(max) = zod_attrs.max_length {
302                        validation = quote! { #validation.max(#max) };
303                    }
304                    if let Some(length) = zod_attrs.length {
305                        validation = quote! { #validation.length(#length) };
306                    }
307                    if zod_attrs.email {
308                        validation = quote! { #validation.email() };
309                    }
310                    if zod_attrs.url {
311                        validation = quote! { #validation.url() };
312                    }
313                    if let Some(regex) = &zod_attrs.regex {
314                        validation = quote! { #validation.regex(#regex) };
315                    }
316                    if let Some(starts_with) = &zod_attrs.starts_with {
317                        validation = quote! { #validation.starts_with(#starts_with) };
318                    }
319                    if let Some(ends_with) = &zod_attrs.ends_with {
320                        validation = quote! { #validation.ends_with(#ends_with) };
321                    }
322                    if let Some(includes) = &zod_attrs.includes {
323                        validation = quote! { #validation.includes(#includes) };
324                    }
325
326                    validation
327                }
328                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize"
329                | "f32" | "f64" => {
330                    let mut validation = quote! { zod_rs::number() };
331
332                    if zod_attrs.int
333                        || matches!(
334                            type_name.as_str(),
335                            "i8" | "i16"
336                                | "i32"
337                                | "i64"
338                                | "u8"
339                                | "u16"
340                                | "u32"
341                                | "u64"
342                                | "isize"
343                                | "usize"
344                        )
345                    {
346                        validation = quote! { #validation.int() };
347                    }
348                    if let Some(min) = zod_attrs.min {
349                        validation = quote! { #validation.min(#min) };
350                    }
351                    if let Some(max) = zod_attrs.max {
352                        validation = quote! { #validation.max(#max) };
353                    }
354                    if zod_attrs.positive {
355                        validation = quote! { #validation.positive() };
356                    }
357                    if zod_attrs.negative {
358                        validation = quote! { #validation.negative() };
359                    }
360                    if zod_attrs.nonnegative {
361                        validation = quote! { #validation.nonnegative() };
362                    }
363                    if zod_attrs.nonpositive {
364                        validation = quote! { #validation.nonpositive() };
365                    }
366                    if zod_attrs.finite {
367                        validation = quote! { #validation.finite() };
368                    }
369
370                    validation
371                }
372                "bool" => {
373                    quote! { zod_rs::boolean() }
374                }
375                "Vec" => {
376                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
377                        if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
378                            let inner_validation = generate_element_validation(inner_type);
379                            let mut validation = quote! { zod_rs::array(#inner_validation) };
380
381                            if let Some(min) = zod_attrs.min_length {
382                                validation = quote! { #validation.min(#min) };
383                            }
384                            if let Some(max) = zod_attrs.max_length {
385                                validation = quote! { #validation.max(#max) };
386                            }
387                            if let Some(length) = zod_attrs.length {
388                                validation = quote! { #validation.length(#length) };
389                            }
390
391                            validation
392                        } else {
393                            quote! { zod_rs::array(zod_rs::string()) }
394                        }
395                    } else {
396                        quote! { zod_rs::array(zod_rs::string()) }
397                    }
398                }
399                _ => {
400                    let type_ident = &segment.ident;
401                    quote! { #type_ident::schema() }
402                }
403            }
404        } else {
405            quote! { zod_rs::string() }
406        }
407    } else {
408        quote! { zod_rs::string() }
409    }
410}
411
412fn generate_element_validation(field_type: &syn::Type) -> proc_macro2::TokenStream {
413    if let syn::Type::Path(type_path) = field_type {
414        if let Some(segment) = type_path.path.segments.last() {
415            let type_name = segment.ident.to_string();
416
417            match type_name.as_str() {
418                "String" => quote! { zod_rs::string() },
419                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize" => {
420                    quote! { zod_rs::number().int() }
421                }
422                "f32" | "f64" => quote! { zod_rs::number() },
423                "bool" => quote! { zod_rs::boolean() },
424                _ => {
425                    let type_ident = &segment.ident;
426                    quote! { #type_ident::schema() }
427                }
428            }
429        } else {
430            quote! { zod_rs::string() }
431        }
432    } else {
433        quote! { zod_rs::string() }
434    }
435}
436
437fn is_option_type(ty: &syn::Type) -> bool {
438    if let syn::Type::Path(type_path) = ty {
439        if let Some(segment) = type_path.path.segments.last() {
440            return segment.ident == "Option";
441        }
442    }
443    false
444}
445
446fn get_option_inner_type(ty: &syn::Type) -> syn::Type {
447    if let syn::Type::Path(type_path) = ty {
448        if let Some(segment) = type_path.path.segments.last() {
449            if segment.ident == "Option" {
450                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
451                    if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
452                        return inner_type.clone();
453                    }
454                }
455            }
456        }
457    }
458    syn::parse_quote! { String }
459}
460
461fn generate_enum_schema(name: &syn::Ident, data_enum: &syn::DataEnum) -> TokenStream {
462    let variant_schemas = data_enum.variants.iter().map(|variant| {
463        let variant_name = &variant.ident;
464        let variant_name_str = variant_name.to_string();
465
466        generate_variant_schema(&variant_name_str, &variant.fields)
467    });
468
469    let expanded = quote! {
470        impl #name {
471            pub fn schema() -> impl zod_rs::Schema<serde_json::Value> {
472                zod_rs::union()
473                    #(#variant_schemas)*
474            }
475
476            pub fn validate_and_parse(value: &serde_json::Value) -> Result<Self, ::zod_rs::__private::ValidationResult> {
477                match Self::schema().validate(value) {
478                    Ok(_) => {
479                        serde_json::from_value(value.clone())
480                            .map_err(|e| ::zod_rs::__private::ValidationError::custom(format!("Deserialization failed: {}", e)).into())
481                    }
482                    Err(validation_result) => Err(validation_result)
483                }
484            }
485
486            pub fn from_json(json_str: &str) -> Result<Self, ::zod_rs::__private::ParseError> {
487                let value: serde_json::Value = serde_json::from_str(json_str)?;
488                Ok(Self::validate_and_parse(&value)?)
489            }
490
491            pub fn validate_json(json_str: &str) -> Result<serde_json::Value, ::zod_rs::__private::ParseError> {
492                let value: serde_json::Value = serde_json::from_str(json_str)?;
493                Self::schema().validate(&value)?;
494                Ok(value)
495            }
496        }
497    };
498
499    TokenStream::from(expanded)
500}
501
502fn generate_variant_schema(variant_name: &str, fields: &Fields) -> proc_macro2::TokenStream {
503    match fields {
504        // Unit variant: {"VariantName": null}
505        Fields::Unit => {
506            quote! {
507                .variant(
508                    zod_rs::object()
509                        .field(#variant_name, zod_rs::null())
510                )
511            }
512        }
513
514        // Tuple variant (unnamed fields)
515        Fields::Unnamed(fields_unnamed) => {
516            generate_tuple_variant_schema(variant_name, fields_unnamed)
517        }
518
519        // Struct variant (named fields): {"VariantName": {"field1": ..., "field2": ...}}
520        Fields::Named(fields_named) => {
521            generate_struct_variant_schema(variant_name, fields_named)
522        }
523    }
524}
525
526fn generate_tuple_variant_schema(
527    variant_name: &str,
528    fields: &syn::FieldsUnnamed,
529) -> proc_macro2::TokenStream {
530    let field_count = fields.unnamed.len();
531
532    if field_count == 1 {
533        // Single element: {"VariantName": value}
534        let field = fields.unnamed.first().unwrap();
535        let field_type = &field.ty;
536        let field_attrs = &field.attrs;
537        let inner_validation =
538            generate_base_validation_with_attrs(field_type, &parse_zod_attributes(field_attrs));
539
540        quote! {
541            .variant(
542                zod_rs::object()
543                    .field(#variant_name, #inner_validation)
544            )
545        }
546    } else {
547        // Multiple elements: {"VariantName": [value1, value2, ...]}
548        let element_validations = fields.unnamed.iter().map(|field| {
549            let field_type = &field.ty;
550            let field_attrs = &field.attrs;
551            generate_base_validation_with_attrs(field_type, &parse_zod_attributes(field_attrs))
552        });
553
554        quote! {
555            .variant(
556                zod_rs::object()
557                    .field(#variant_name, zod_rs::tuple()
558                        #(.element(#element_validations))*
559                    )
560            )
561        }
562    }
563}
564
565fn generate_struct_variant_schema(
566    variant_name: &str,
567    fields: &syn::FieldsNamed,
568) -> proc_macro2::TokenStream {
569    let field_validations = fields.named.iter().map(|field| {
570        let field_name = &field.ident;
571        let field_name_str = field_name.as_ref().unwrap().to_string();
572        let field_type = &field.ty;
573        let field_attrs = &field.attrs;
574
575        generate_field_validation_with_attrs(&field_name_str, field_type, field_attrs)
576    });
577
578    quote! {
579        .variant(
580            zod_rs::object()
581                .field(#variant_name, zod_rs::object()
582                    #(#field_validations)*
583                )
584        )
585    }
586}
587
588#[proc_macro]
589pub fn infer_struct(_input: TokenStream) -> TokenStream {
590    let expanded = quote! {
591        compile_error!("infer_struct macro is not yet implemented. Use #[derive(ZodSchema)] instead.");
592    };
593
594    TokenStream::from(expanded)
595}