1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5#[proc_macro_derive(ZodSchema, attributes(zod))]
6pub fn derive_zod_schema(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = &input.ident;
9
10 match &input.data {
11 Data::Struct(data_struct) => match &data_struct.fields {
12 Fields::Named(fields) => {
13 let field_validations = fields.named.iter().map(|field| {
14 let field_name = &field.ident;
15 let field_name_str = field_name.as_ref().unwrap().to_string();
16 let field_type = &field.ty;
17 let field_attrs = &field.attrs;
18
19 generate_field_validation_with_attrs(&field_name_str, field_type, field_attrs)
20 });
21
22 let expanded = quote! {
23 impl #name {
24 pub fn schema() -> impl zod_rs::Schema<serde_json::Value> {
25 zod_rs::object()
26 #(#field_validations)*
27 }
28
29 pub fn validate_and_parse(value: &serde_json::Value) -> Result<Self, ::zod_rs::__private::ValidationResult> {
30 match Self::schema().validate(value) {
31 Ok(_) => {
32 serde_json::from_value(value.clone())
33 .map_err(|e| ::zod_rs::__private::ValidationError::custom(format!("Deserialization failed: {}", e)).into())
34 }
35 Err(validation_result) => Err(validation_result)
36 }
37 }
38
39 pub fn from_json(json_str: &str) -> Result<Self, ::zod_rs::__private::ParseError> {
40 let value: serde_json::Value = serde_json::from_str(json_str)?;
41 Ok(Self::validate_and_parse(&value)?)
42 }
43
44 pub fn validate_json(json_str: &str) -> Result<serde_json::Value, ::zod_rs::__private::ParseError> {
45 let value: serde_json::Value = serde_json::from_str(json_str)?;
46 Self::schema().validate(&value)?;
47 Ok(value)
48 }
49 }
50 };
51
52 TokenStream::from(expanded)
53 }
54 Fields::Unnamed(_) => {
55 let error = syn::Error::new_spanned(
56 &input,
57 "ZodSchema can only be derived for structs with named fields, not tuple structs",
58 );
59 TokenStream::from(error.to_compile_error())
60 }
61 Fields::Unit => {
62 let error = syn::Error::new_spanned(
63 &input,
64 "ZodSchema can only be derived for structs with named fields, not unit structs",
65 );
66 TokenStream::from(error.to_compile_error())
67 }
68 },
69 Data::Enum(data_enum) => generate_enum_schema(name, data_enum),
70 Data::Union(_) => {
71 let error = syn::Error::new_spanned(
72 &input,
73 "ZodSchema cannot be derived for unions",
74 );
75 TokenStream::from(error.to_compile_error())
76 }
77 }
78}
79
/// Constraints collected from the `#[zod(...)]` attributes of a single field.
///
/// Every option starts out unset (`None` / `false`); `parse_zod_attributes`
/// fills in the ones that appear on the field, and the schema generators turn
/// the set ones into builder calls.
#[derive(Default)]
struct ZodAttributes {
    // --- size constraints, applied to `String` and `Vec` schemas ---
    /// `length(n)`: emitted as `.length(n)`.
    length: Option<usize>,
    /// `min_length(n)`: emitted as `.min(n)`.
    min_length: Option<usize>,
    /// `max_length(n)`: emitted as `.max(n)`.
    max_length: Option<usize>,

    // --- string-only constraints ---
    /// `email`: emitted as `.email()`.
    email: bool,
    /// `url`: emitted as `.url()`.
    url: bool,
    /// `regex("...")`: emitted as `.regex(pattern)`.
    regex: Option<String>,
    /// `starts_with("...")`: emitted as `.starts_with(prefix)`.
    starts_with: Option<String>,
    /// `ends_with("...")`: emitted as `.ends_with(suffix)`.
    ends_with: Option<String>,
    /// `includes("...")`: emitted as `.includes(needle)`.
    includes: Option<String>,

    // --- numeric-only constraints ---
    /// `min(x)`: emitted as `.min(x)`.
    min: Option<f64>,
    /// `max(x)`: emitted as `.max(x)`.
    max: Option<f64>,
    /// `int`: emitted as `.int()` (integer field types get this implicitly).
    int: bool,
    /// `positive`: emitted as `.positive()`.
    positive: bool,
    /// `negative`: emitted as `.negative()`.
    negative: bool,
    /// `nonnegative`: emitted as `.nonnegative()`.
    nonnegative: bool,
    /// `nonpositive`: emitted as `.nonpositive()`.
    nonpositive: bool,
    /// `finite`: emitted as `.finite()`.
    finite: bool,
}
100
101fn parse_zod_attributes(attrs: &[Attribute]) -> ZodAttributes {
102 let mut zod_attrs = ZodAttributes::default();
103
104 for attr in attrs {
105 if attr.path().is_ident("zod") {
106 if let Meta::List(meta_list) = &attr.meta {
107 let tokens: Vec<_> = meta_list.tokens.clone().into_iter().collect();
108 let mut i = 0;
109
110 while i < tokens.len() {
111 let token_str = tokens[i].to_string();
112
113 match token_str.as_str() {
114 "min_length" => {
115 if i + 1 < tokens.len() {
116 let value_token = tokens[i + 1].to_string();
117 if let Some(value) = extract_number_from_parens(&value_token) {
118 zod_attrs.min_length = Some(value);
119 }
120 i += 1; }
122 }
123 "max_length" => {
124 if i + 1 < tokens.len() {
125 let value_token = tokens[i + 1].to_string();
126 if let Some(value) = extract_number_from_parens(&value_token) {
127 zod_attrs.max_length = Some(value);
128 }
129 i += 1;
130 }
131 }
132 "length" => {
133 if i + 1 < tokens.len() {
134 let value_token = tokens[i + 1].to_string();
135 if let Some(value) = extract_number_from_parens(&value_token) {
136 zod_attrs.length = Some(value);
137 }
138 i += 1;
139 }
140 }
141 "min" => {
142 if i + 1 < tokens.len() {
143 let value_token = tokens[i + 1].to_string();
144 if let Some(value_str) = extract_string_from_parens(&value_token) {
145 if let Ok(value) = value_str.parse::<f64>() {
146 zod_attrs.min = Some(value);
147 }
148 }
149 i += 1;
150 }
151 }
152 "max" => {
153 if i + 1 < tokens.len() {
154 let value_token = tokens[i + 1].to_string();
155 if let Some(value_str) = extract_string_from_parens(&value_token) {
156 if let Ok(value) = value_str.parse::<f64>() {
157 zod_attrs.max = Some(value);
158 }
159 }
160 i += 1;
161 }
162 }
163 "starts_with" => {
164 if i + 1 < tokens.len() {
165 let value_token = tokens[i + 1].to_string();
166 if let Some(value) = extract_string_from_parens(&value_token) {
167 zod_attrs.starts_with = Some(strip_quotes(&value));
168 }
169 i += 1;
170 }
171 }
172 "ends_with" => {
173 if i + 1 < tokens.len() {
174 let value_token = tokens[i + 1].to_string();
175 if let Some(value) = extract_string_from_parens(&value_token) {
176 zod_attrs.ends_with = Some(strip_quotes(&value));
177 }
178 i += 1;
179 }
180 }
181 "includes" => {
182 if i + 1 < tokens.len() {
183 let value_token = tokens[i + 1].to_string();
184 if let Some(value) = extract_string_from_parens(&value_token) {
185 zod_attrs.includes = Some(strip_quotes(&value));
186 }
187 i += 1;
188 }
189 }
190 "regex" => {
191 if i + 1 < tokens.len() {
192 let value_token = tokens[i + 1].to_string();
193 if let Some(value) = extract_string_from_parens(&value_token) {
194 zod_attrs.regex = Some(strip_quotes(&value));
195 }
196 i += 1;
197 }
198 }
199 "email" => {
200 zod_attrs.email = true;
201 }
202 "url" => {
203 zod_attrs.url = true;
204 }
205 "positive" => {
206 zod_attrs.positive = true;
207 }
208 "negative" => {
209 zod_attrs.negative = true;
210 }
211 "nonnegative" => {
212 zod_attrs.nonnegative = true;
213 }
214 "nonpositive" => {
215 zod_attrs.nonpositive = true;
216 }
217 "int" => {
218 zod_attrs.int = true;
219 }
220 "finite" => {
221 zod_attrs.finite = true;
222 }
223 "," => {
224 }
226 _ => {
227 }
229 }
230
231 i += 1;
232 }
233 }
234 }
235 }
236
237 zod_attrs
238}
239
/// Parses the stringified argument group of a size option, e.g. `"(5)"`,
/// into a `usize`.
///
/// Returns `None` when `token` is not wrapped in parentheses or the content
/// is not an unsigned integer.
///
/// The input is the `to_string()` rendering of a token group, whose exact
/// whitespace is not guaranteed, so padding around and inside the parentheses
/// (`"( 5 )"`) is tolerated. `_` digit separators, which Rust integer literals
/// allow (`1_000`), are accepted as well.
fn extract_number_from_parens(token: &str) -> Option<usize> {
    let inner = token.trim().strip_prefix('(')?.strip_suffix(')')?;
    inner.trim().replace('_', "").parse().ok()
}
246
/// Returns the text between the outer parentheses of a stringified argument
/// group, e.g. `"(\"abc\")"` yields `"\"abc\""`.
///
/// Returns `None` when `token` is not wrapped in parentheses.
///
/// The input is the `to_string()` rendering of a token group, whose exact
/// whitespace is not guaranteed. Padding outside the parentheses and directly
/// inside them is therefore removed, so a rendering such as `( "abc" )`
/// yields the same result as `("abc")`. Whitespace inside a quoted literal is
/// untouched because it sits between the quotes.
fn extract_string_from_parens(token: &str) -> Option<String> {
    let inner = token.trim().strip_prefix('(')?.strip_suffix(')')?;
    Some(inner.trim().to_string())
}
253
/// Decodes the escape sequences of an ordinary (non-raw) Rust string literal
/// body, i.e. the text between the quotes.
///
/// Handles `\n`, `\r`, `\t`, `\0`, `\\`, `\'`, `\"`, `\xNN` (up to `7F`),
/// `\u{...}` and the backslash-newline line continuation. A sequence that is
/// not recognised is copied through unchanged.
fn unescape_str_literal(body: &str) -> String {
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }

        match chars.next() {
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some('0') => out.push('\0'),
            Some('\\') => out.push('\\'),
            Some('\'') => out.push('\''),
            Some('"') => out.push('"'),
            Some('x') => {
                let hex: String = chars.by_ref().take(2).collect();
                match u8::from_str_radix(&hex, 16) {
                    Ok(byte) if hex.len() == 2 && byte <= 0x7f => out.push(char::from(byte)),
                    _ => {
                        out.push_str("\\x");
                        out.push_str(&hex);
                    }
                }
            }
            Some('u') => {
                // `raw` keeps what was consumed so it can be echoed back if
                // the sequence turns out to be malformed.
                let mut raw = String::new();
                let mut decoded = None;
                if chars.peek() == Some(&'{') {
                    let mut closed = false;
                    for d in chars.by_ref() {
                        raw.push(d);
                        if d == '}' {
                            closed = true;
                            break;
                        }
                    }
                    if closed {
                        let digits: String = raw[1..raw.len() - 1]
                            .chars()
                            .filter(|d| *d != '_')
                            .collect();
                        decoded = u32::from_str_radix(&digits, 16)
                            .ok()
                            .and_then(char::from_u32);
                    }
                }
                match decoded {
                    Some(ch) => out.push(ch),
                    None => {
                        out.push_str("\\u");
                        out.push_str(&raw);
                    }
                }
            }
            Some('\n') => {
                // Line continuation: drop the newline and the indentation
                // that follows it.
                while chars
                    .peek()
                    .map_or(false, |w| matches!(*w, ' ' | '\t' | '\n' | '\r'))
                {
                    chars.next();
                }
            }
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }

    out
}

/// Converts the source text of a string literal into the string it denotes.
///
/// * `"..."` – the quotes are removed and escape sequences are decoded, so
///   `"^\\d+$"` yields `^\d+$` (one backslash), the same value the literal
///   has in ordinary Rust code.
/// * `r"..."`, `r#"..."#`, `r##"..."##`, … – the raw-string delimiters are
///   removed and the content is returned verbatim.
/// * Anything else is returned unchanged.
fn strip_quotes(value: &str) -> String {
    // Raw string: `r`, any number of `#`, a quoted body, the same `#` run.
    if let Some(rest) = value.strip_prefix('r') {
        let hashes = rest.bytes().take_while(|b| *b == b'#').count();
        let fence = &rest[..hashes];
        let body = rest[hashes..]
            .strip_prefix('"')
            .and_then(|s| s.strip_suffix(fence))
            .and_then(|s| s.strip_suffix('"'));
        if let Some(body) = body {
            return body.to_string();
        }
    }

    if let Some(inner) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        return unescape_str_literal(inner);
    }

    value.to_string()
}
267
268fn generate_field_validation_with_attrs(
269 field_name: &str,
270 field_type: &syn::Type,
271 attrs: &[Attribute],
272) -> proc_macro2::TokenStream {
273 let zod_attrs = parse_zod_attributes(attrs);
274 let is_optional = is_option_type(field_type);
275
276 if is_optional {
277 let inner_type = get_option_inner_type(field_type);
278 let base_validation = generate_base_validation_with_attrs(&inner_type, &zod_attrs);
279 quote! { .optional_field(#field_name, #base_validation) }
280 } else {
281 let base_validation = generate_base_validation_with_attrs(field_type, &zod_attrs);
282 quote! { .field(#field_name, #base_validation) }
283 }
284}
285
286fn generate_base_validation_with_attrs(
287 field_type: &syn::Type,
288 zod_attrs: &ZodAttributes,
289) -> proc_macro2::TokenStream {
290 if let syn::Type::Path(type_path) = field_type {
291 if let Some(segment) = type_path.path.segments.last() {
292 let type_name = segment.ident.to_string();
293
294 match type_name.as_str() {
295 "String" => {
296 let mut validation = quote! { zod_rs::string() };
297
298 if let Some(min) = zod_attrs.min_length {
299 validation = quote! { #validation.min(#min) };
300 }
301 if let Some(max) = zod_attrs.max_length {
302 validation = quote! { #validation.max(#max) };
303 }
304 if let Some(length) = zod_attrs.length {
305 validation = quote! { #validation.length(#length) };
306 }
307 if zod_attrs.email {
308 validation = quote! { #validation.email() };
309 }
310 if zod_attrs.url {
311 validation = quote! { #validation.url() };
312 }
313 if let Some(regex) = &zod_attrs.regex {
314 validation = quote! { #validation.regex(#regex) };
315 }
316 if let Some(starts_with) = &zod_attrs.starts_with {
317 validation = quote! { #validation.starts_with(#starts_with) };
318 }
319 if let Some(ends_with) = &zod_attrs.ends_with {
320 validation = quote! { #validation.ends_with(#ends_with) };
321 }
322 if let Some(includes) = &zod_attrs.includes {
323 validation = quote! { #validation.includes(#includes) };
324 }
325
326 validation
327 }
328 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize"
329 | "f32" | "f64" => {
330 let mut validation = quote! { zod_rs::number() };
331
332 if zod_attrs.int
333 || matches!(
334 type_name.as_str(),
335 "i8" | "i16"
336 | "i32"
337 | "i64"
338 | "u8"
339 | "u16"
340 | "u32"
341 | "u64"
342 | "isize"
343 | "usize"
344 )
345 {
346 validation = quote! { #validation.int() };
347 }
348 if let Some(min) = zod_attrs.min {
349 validation = quote! { #validation.min(#min) };
350 }
351 if let Some(max) = zod_attrs.max {
352 validation = quote! { #validation.max(#max) };
353 }
354 if zod_attrs.positive {
355 validation = quote! { #validation.positive() };
356 }
357 if zod_attrs.negative {
358 validation = quote! { #validation.negative() };
359 }
360 if zod_attrs.nonnegative {
361 validation = quote! { #validation.nonnegative() };
362 }
363 if zod_attrs.nonpositive {
364 validation = quote! { #validation.nonpositive() };
365 }
366 if zod_attrs.finite {
367 validation = quote! { #validation.finite() };
368 }
369
370 validation
371 }
372 "bool" => {
373 quote! { zod_rs::boolean() }
374 }
375 "Vec" => {
376 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
377 if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
378 let inner_validation = generate_element_validation(inner_type);
379 let mut validation = quote! { zod_rs::array(#inner_validation) };
380
381 if let Some(min) = zod_attrs.min_length {
382 validation = quote! { #validation.min(#min) };
383 }
384 if let Some(max) = zod_attrs.max_length {
385 validation = quote! { #validation.max(#max) };
386 }
387 if let Some(length) = zod_attrs.length {
388 validation = quote! { #validation.length(#length) };
389 }
390
391 validation
392 } else {
393 quote! { zod_rs::array(zod_rs::string()) }
394 }
395 } else {
396 quote! { zod_rs::array(zod_rs::string()) }
397 }
398 }
399 _ => {
400 let type_ident = &segment.ident;
401 quote! { #type_ident::schema() }
402 }
403 }
404 } else {
405 quote! { zod_rs::string() }
406 }
407 } else {
408 quote! { zod_rs::string() }
409 }
410}
411
412fn generate_element_validation(field_type: &syn::Type) -> proc_macro2::TokenStream {
413 if let syn::Type::Path(type_path) = field_type {
414 if let Some(segment) = type_path.path.segments.last() {
415 let type_name = segment.ident.to_string();
416
417 match type_name.as_str() {
418 "String" => quote! { zod_rs::string() },
419 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize" => {
420 quote! { zod_rs::number().int() }
421 }
422 "f32" | "f64" => quote! { zod_rs::number() },
423 "bool" => quote! { zod_rs::boolean() },
424 _ => {
425 let type_ident = &segment.ident;
426 quote! { #type_ident::schema() }
427 }
428 }
429 } else {
430 quote! { zod_rs::string() }
431 }
432 } else {
433 quote! { zod_rs::string() }
434 }
435}
436
437fn is_option_type(ty: &syn::Type) -> bool {
438 if let syn::Type::Path(type_path) = ty {
439 if let Some(segment) = type_path.path.segments.last() {
440 return segment.ident == "Option";
441 }
442 }
443 false
444}
445
446fn get_option_inner_type(ty: &syn::Type) -> syn::Type {
447 if let syn::Type::Path(type_path) = ty {
448 if let Some(segment) = type_path.path.segments.last() {
449 if segment.ident == "Option" {
450 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
451 if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
452 return inner_type.clone();
453 }
454 }
455 }
456 }
457 }
458 syn::parse_quote! { String }
459}
460
461fn generate_enum_schema(name: &syn::Ident, data_enum: &syn::DataEnum) -> TokenStream {
462 let variant_schemas = data_enum.variants.iter().map(|variant| {
463 let variant_name = &variant.ident;
464 let variant_name_str = variant_name.to_string();
465
466 generate_variant_schema(&variant_name_str, &variant.fields)
467 });
468
469 let expanded = quote! {
470 impl #name {
471 pub fn schema() -> impl zod_rs::Schema<serde_json::Value> {
472 zod_rs::union()
473 #(#variant_schemas)*
474 }
475
476 pub fn validate_and_parse(value: &serde_json::Value) -> Result<Self, ::zod_rs::__private::ValidationResult> {
477 match Self::schema().validate(value) {
478 Ok(_) => {
479 serde_json::from_value(value.clone())
480 .map_err(|e| ::zod_rs::__private::ValidationError::custom(format!("Deserialization failed: {}", e)).into())
481 }
482 Err(validation_result) => Err(validation_result)
483 }
484 }
485
486 pub fn from_json(json_str: &str) -> Result<Self, ::zod_rs::__private::ParseError> {
487 let value: serde_json::Value = serde_json::from_str(json_str)?;
488 Ok(Self::validate_and_parse(&value)?)
489 }
490
491 pub fn validate_json(json_str: &str) -> Result<serde_json::Value, ::zod_rs::__private::ParseError> {
492 let value: serde_json::Value = serde_json::from_str(json_str)?;
493 Self::schema().validate(&value)?;
494 Ok(value)
495 }
496 }
497 };
498
499 TokenStream::from(expanded)
500}
501
502fn generate_variant_schema(variant_name: &str, fields: &Fields) -> proc_macro2::TokenStream {
503 match fields {
504 Fields::Unit => {
506 quote! {
507 .variant(
508 zod_rs::object()
509 .field(#variant_name, zod_rs::null())
510 )
511 }
512 }
513
514 Fields::Unnamed(fields_unnamed) => {
516 generate_tuple_variant_schema(variant_name, fields_unnamed)
517 }
518
519 Fields::Named(fields_named) => {
521 generate_struct_variant_schema(variant_name, fields_named)
522 }
523 }
524}
525
526fn generate_tuple_variant_schema(
527 variant_name: &str,
528 fields: &syn::FieldsUnnamed,
529) -> proc_macro2::TokenStream {
530 let field_count = fields.unnamed.len();
531
532 if field_count == 1 {
533 let field = fields.unnamed.first().unwrap();
535 let field_type = &field.ty;
536 let field_attrs = &field.attrs;
537 let inner_validation =
538 generate_base_validation_with_attrs(field_type, &parse_zod_attributes(field_attrs));
539
540 quote! {
541 .variant(
542 zod_rs::object()
543 .field(#variant_name, #inner_validation)
544 )
545 }
546 } else {
547 let element_validations = fields.unnamed.iter().map(|field| {
549 let field_type = &field.ty;
550 let field_attrs = &field.attrs;
551 generate_base_validation_with_attrs(field_type, &parse_zod_attributes(field_attrs))
552 });
553
554 quote! {
555 .variant(
556 zod_rs::object()
557 .field(#variant_name, zod_rs::tuple()
558 #(.element(#element_validations))*
559 )
560 )
561 }
562 }
563}
564
565fn generate_struct_variant_schema(
566 variant_name: &str,
567 fields: &syn::FieldsNamed,
568) -> proc_macro2::TokenStream {
569 let field_validations = fields.named.iter().map(|field| {
570 let field_name = &field.ident;
571 let field_name_str = field_name.as_ref().unwrap().to_string();
572 let field_type = &field.ty;
573 let field_attrs = &field.attrs;
574
575 generate_field_validation_with_attrs(&field_name_str, field_type, field_attrs)
576 });
577
578 quote! {
579 .variant(
580 zod_rs::object()
581 .field(#variant_name, zod_rs::object()
582 #(#field_validations)*
583 )
584 )
585 }
586}
587
588#[proc_macro]
589pub fn infer_struct(_input: TokenStream) -> TokenStream {
590 let expanded = quote! {
591 compile_error!("infer_struct macro is not yet implemented. Use #[derive(ZodSchema)] instead.");
592 };
593
594 TokenStream::from(expanded)
595}