1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5#[proc_macro_derive(ZodTs, attributes(zod))]
6pub fn derive_zod_ts(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = &input.ident;
9 let name_str = name.to_string();
10
11 match &input.data {
12 Data::Struct(data_struct) => match &data_struct.fields {
13 Fields::Named(fields) => {
14 let field_schemas: Vec<String> = fields
15 .named
16 .iter()
17 .map(|field| {
18 let field_name = field.ident.as_ref().unwrap().to_string();
19 let field_type = &field.ty;
20 let attrs = parse_zod_attributes(&field.attrs);
21 let is_optional = is_option_type(field_type);
22
23 let base_type = if is_optional {
24 get_option_inner_type_str(field_type)
25 } else {
26 type_to_string(field_type)
27 };
28
29 let zod_type = rust_type_to_zod(&base_type, &attrs);
30 let final_type = if is_optional {
31 format!("{}.optional()", zod_type)
32 } else {
33 zod_type
34 };
35
36 format!(" {}: {}", field_name, final_type)
37 })
38 .collect();
39
40 let fields_str = field_schemas.join(",\n");
41 let schema_name = format!("{}Schema", name_str);
42
43 let ts_code = format!(
44 r#"import {{ z }} from 'zod';
45
46export const {} = z.object({{
47{}
48}});
49
50export type {} = z.infer<typeof {}>;"#,
51 schema_name, fields_str, name_str, schema_name
52 );
53
54 let expanded = quote! {
55 impl #name {
56 pub fn zod_ts() -> String {
57 #ts_code.to_string()
58 }
59 }
60 };
61
62 TokenStream::from(expanded)
63 }
64 _ => {
65 let error = syn::Error::new_spanned(
66 &input,
67 "ZodTs can only be derived for structs with named fields",
68 );
69 TokenStream::from(error.to_compile_error())
70 }
71 },
72 Data::Enum(data_enum) => {
73 let variant_schemas: Vec<String> = data_enum
74 .variants
75 .iter()
76 .map(|variant| {
77 let variant_name = variant.ident.to_string();
78 generate_variant_ts(&variant_name, &variant.fields)
79 })
80 .collect();
81
82 let variants_str = variant_schemas.join(",\n ");
83 let schema_name = format!("{}Schema", name_str);
84
85 let ts_code = format!(
86 r#"import {{ z }} from 'zod';
87
88export const {} = z.union([
89 {}
90]);
91
92export type {} = z.infer<typeof {}>;"#,
93 schema_name, variants_str, name_str, schema_name
94 );
95
96 let expanded = quote! {
97 impl #name {
98 pub fn zod_ts() -> String {
99 #ts_code.to_string()
100 }
101 }
102 };
103
104 TokenStream::from(expanded)
105 }
106 Data::Union(_) => {
107 let error =
108 syn::Error::new_spanned(&input, "ZodTs cannot be derived for Rust unions");
109 TokenStream::from(error.to_compile_error())
110 }
111 }
112}
113
114fn generate_variant_ts(variant_name: &str, fields: &Fields) -> String {
115 match fields {
116 Fields::Unit => {
117 format!("z.object({{ {}: z.null() }})", variant_name)
118 }
119 Fields::Unnamed(fields_unnamed) => {
120 let field_count = fields_unnamed.unnamed.len();
121 if field_count == 1 {
122 let field = fields_unnamed.unnamed.first().unwrap();
123 let field_type = type_to_string(&field.ty);
124 let attrs = parse_zod_attributes(&field.attrs);
125 let zod_type = rust_type_to_zod(&field_type, &attrs);
126 format!("z.object({{ {}: {} }})", variant_name, zod_type)
127 } else {
128 let element_types: Vec<String> = fields_unnamed
129 .unnamed
130 .iter()
131 .map(|field| {
132 let field_type = type_to_string(&field.ty);
133 let attrs = parse_zod_attributes(&field.attrs);
134 rust_type_to_zod(&field_type, &attrs)
135 })
136 .collect();
137 let tuple_str = element_types.join(", ");
138 format!("z.object({{ {}: z.tuple([{}]) }})", variant_name, tuple_str)
139 }
140 }
141 Fields::Named(fields_named) => {
142 let field_schemas: Vec<String> = fields_named
143 .named
144 .iter()
145 .map(|field| {
146 let field_name = field.ident.as_ref().unwrap().to_string();
147 let field_type = type_to_string(&field.ty);
148 let attrs = parse_zod_attributes(&field.attrs);
149 let is_optional = is_option_type(&field.ty);
150
151 let base_type = if is_optional {
152 get_option_inner_type_str(&field.ty)
153 } else {
154 field_type
155 };
156
157 let zod_type = rust_type_to_zod(&base_type, &attrs);
158 let final_type = if is_optional {
159 format!("{}.optional()", zod_type)
160 } else {
161 zod_type
162 };
163
164 format!("{}: {}", field_name, final_type)
165 })
166 .collect();
167 let fields_str = field_schemas.join(", ");
168 format!(
169 "z.object({{ {}: z.object({{ {} }}) }})",
170 variant_name, fields_str
171 )
172 }
173 }
174}
175
/// Validation constraints collected from the `#[zod(...)]` attributes of a
/// single field. Every constraint defaults to "absent" (`None` / `false`).
#[derive(Default)]
struct ZodAttributes {
    // --- size constraints (strings and arrays) ---
    /// Exact length, from `length(n)`.
    length: Option<usize>,
    /// Minimum length, from `min_length(n)`.
    min_length: Option<usize>,
    /// Maximum length, from `max_length(n)`.
    max_length: Option<usize>,

    // --- string content constraints ---
    /// Required prefix, from `starts_with("...")`.
    starts_with: Option<String>,
    /// Required suffix, from `ends_with("...")`.
    ends_with: Option<String>,
    /// Required substring, from `includes("...")`.
    includes: Option<String>,
    /// Regular-expression source, from `regex("...")`.
    regex: Option<String>,
    /// Set by the bare `email` flag.
    email: bool,
    /// Set by the bare `url` flag.
    url: bool,

    // --- numeric constraints ---
    /// Lower bound, from `min(x)`.
    min: Option<f64>,
    /// Upper bound, from `max(x)`.
    max: Option<f64>,
    /// Set by the bare `positive` flag.
    positive: bool,
    /// Set by the bare `negative` flag.
    negative: bool,
    /// Set by the bare `nonnegative` flag.
    nonnegative: bool,
    /// Set by the bare `nonpositive` flag.
    nonpositive: bool,
    /// Set by the bare `int` flag.
    int: bool,
    /// Set by the bare `finite` flag.
    finite: bool,
}
196
197fn parse_zod_attributes(attrs: &[Attribute]) -> ZodAttributes {
198 let mut zod_attrs = ZodAttributes::default();
199
200 for attr in attrs {
201 if attr.path().is_ident("zod") {
202 if let Meta::List(meta_list) = &attr.meta {
203 let tokens: Vec<_> = meta_list.tokens.clone().into_iter().collect();
204 let mut i = 0;
205
206 while i < tokens.len() {
207 let token_str = tokens[i].to_string();
208
209 match token_str.as_str() {
210 "min_length" => {
211 if i + 1 < tokens.len() {
212 let value_token = tokens[i + 1].to_string();
213 if let Some(value) = extract_number_from_parens(&value_token) {
214 zod_attrs.min_length = Some(value);
215 }
216 i += 1;
217 }
218 }
219 "max_length" => {
220 if i + 1 < tokens.len() {
221 let value_token = tokens[i + 1].to_string();
222 if let Some(value) = extract_number_from_parens(&value_token) {
223 zod_attrs.max_length = Some(value);
224 }
225 i += 1;
226 }
227 }
228 "length" => {
229 if i + 1 < tokens.len() {
230 let value_token = tokens[i + 1].to_string();
231 if let Some(value) = extract_number_from_parens(&value_token) {
232 zod_attrs.length = Some(value);
233 }
234 i += 1;
235 }
236 }
237 "min" => {
238 if i + 1 < tokens.len() {
239 let value_token = tokens[i + 1].to_string();
240 if let Some(value_str) = extract_string_from_parens(&value_token) {
241 if let Ok(value) = value_str.parse::<f64>() {
242 zod_attrs.min = Some(value);
243 }
244 }
245 i += 1;
246 }
247 }
248 "max" => {
249 if i + 1 < tokens.len() {
250 let value_token = tokens[i + 1].to_string();
251 if let Some(value_str) = extract_string_from_parens(&value_token) {
252 if let Ok(value) = value_str.parse::<f64>() {
253 zod_attrs.max = Some(value);
254 }
255 }
256 i += 1;
257 }
258 }
259 "starts_with" => {
260 if i + 1 < tokens.len() {
261 let value_token = tokens[i + 1].to_string();
262 if let Some(value) = extract_string_from_parens(&value_token) {
263 zod_attrs.starts_with = Some(strip_quotes(&value));
264 }
265 i += 1;
266 }
267 }
268 "ends_with" => {
269 if i + 1 < tokens.len() {
270 let value_token = tokens[i + 1].to_string();
271 if let Some(value) = extract_string_from_parens(&value_token) {
272 zod_attrs.ends_with = Some(strip_quotes(&value));
273 }
274 i += 1;
275 }
276 }
277 "includes" => {
278 if i + 1 < tokens.len() {
279 let value_token = tokens[i + 1].to_string();
280 if let Some(value) = extract_string_from_parens(&value_token) {
281 zod_attrs.includes = Some(strip_quotes(&value));
282 }
283 i += 1;
284 }
285 }
286 "regex" => {
287 if i + 1 < tokens.len() {
288 let value_token = tokens[i + 1].to_string();
289 if let Some(value) = extract_string_from_parens(&value_token) {
290 zod_attrs.regex = Some(strip_quotes(&value));
291 }
292 i += 1;
293 }
294 }
295 "email" => {
296 zod_attrs.email = true;
297 }
298 "url" => {
299 zod_attrs.url = true;
300 }
301 "positive" => {
302 zod_attrs.positive = true;
303 }
304 "negative" => {
305 zod_attrs.negative = true;
306 }
307 "nonnegative" => {
308 zod_attrs.nonnegative = true;
309 }
310 "nonpositive" => {
311 zod_attrs.nonpositive = true;
312 }
313 "int" => {
314 zod_attrs.int = true;
315 }
316 "finite" => {
317 zod_attrs.finite = true;
318 }
319 "," => {}
320 _ => {}
321 }
322
323 i += 1;
324 }
325 }
326 }
327 }
328
329 zod_attrs
330}
331
/// Parses the text of a parenthesised token group such as `(42)` into a
/// `usize`.
///
/// Whitespace around the parentheses and around the number is ignored, because
/// the textual rendering of a token stream does not guarantee its spacing.
/// Rust-style digit separators (`1_000`) are accepted. Returns `None` when the
/// text is not wrapped in `(` and `)` or the content is not a non-negative
/// integer.
fn extract_number_from_parens(token: &str) -> Option<usize> {
    let inner = token.trim().strip_prefix('(')?.strip_suffix(')')?.trim();

    // `_5` is an identifier in Rust, not a number; only strip separators that
    // follow a leading digit or sign.
    if inner.starts_with('_') {
        return None;
    }

    let digits: String = inner.chars().filter(|&ch| ch != '_').collect();
    digits.parse::<usize>().ok()
}
338
/// Returns the text between the outer parentheses of a token group such as
/// `("abc")` or `(1.5)`.
///
/// Whitespace around the parentheses and around the content is trimmed,
/// because the textual rendering of a token stream does not guarantee its
/// spacing. Whitespace inside a quoted literal is untouched since the quotes
/// are part of the returned text. Returns `None` when the text is not wrapped
/// in `(` and `)`.
fn extract_string_from_parens(token: &str) -> Option<String> {
    let inner = token.trim().strip_prefix('(')?.strip_suffix(')')?;
    Some(inner.trim().to_string())
}
345
/// Removes the delimiters from the source text of a Rust string literal.
///
/// Handles ordinary literals (`"abc"`) and raw literals with any number of
/// hashes (`r"abc"`, `r#"abc"#`, `r##"abc"##`, ...). Escape sequences inside an
/// ordinary literal are returned as written, not decoded. Text that is not a
/// recognised string literal is returned unchanged.
fn strip_quotes(value: &str) -> String {
    if let Some(inner) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        return inner.to_string();
    }

    if let Some(rest) = value.strip_prefix('r') {
        // A raw literal closes with a quote followed by the same number of
        // hashes that preceded its opening quote.
        let body = rest.trim_start_matches('#');
        let closing_hashes = "#".repeat(rest.len() - body.len());
        let inner = body
            .strip_prefix('"')
            .and_then(|s| s.strip_suffix(closing_hashes.as_str()))
            .and_then(|s| s.strip_suffix('"'));
        if let Some(inner) = inner {
            return inner.to_string();
        }
    }

    value.to_string()
}
355
356fn rust_type_to_zod(rust_type: &str, attrs: &ZodAttributes) -> String {
357 let base = match rust_type {
358 "String" | "&str" | "str" => {
359 let mut chain = String::from("z.string()");
360
361 if let Some(len) = attrs.length {
362 chain.push_str(&format!(".length({})", len));
363 }
364 if let Some(min) = attrs.min_length {
365 chain.push_str(&format!(".min({})", min));
366 }
367 if let Some(max) = attrs.max_length {
368 chain.push_str(&format!(".max({})", max));
369 }
370 if attrs.email {
371 chain.push_str(".email()");
372 }
373 if attrs.url {
374 chain.push_str(".url()");
375 }
376 if let Some(ref pattern) = attrs.regex {
377 chain.push_str(&format!(".regex(/{}/)", pattern));
378 }
379 if let Some(ref prefix) = attrs.starts_with {
380 chain.push_str(&format!(".startsWith(\"{}\")", prefix));
381 }
382 if let Some(ref suffix) = attrs.ends_with {
383 chain.push_str(&format!(".endsWith(\"{}\")", suffix));
384 }
385 if let Some(ref substr) = attrs.includes {
386 chain.push_str(&format!(".includes(\"{}\")", substr));
387 }
388
389 chain
390 }
391 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
392 | "usize" => {
393 let mut chain = String::from("z.number().int()");
394 append_number_validators(&mut chain, attrs);
395 chain
396 }
397 "f32" | "f64" => {
398 let mut chain = String::from("z.number()");
399 if attrs.int {
400 chain.push_str(".int()");
401 }
402 append_number_validators(&mut chain, attrs);
403 chain
404 }
405 "bool" => String::from("z.boolean()"),
406 other => {
407 if other.starts_with("Vec<") {
408 let inner = other
409 .strip_prefix("Vec<")
410 .and_then(|s| s.strip_suffix('>'))
411 .unwrap_or("unknown");
412 let inner_zod = rust_type_to_zod(inner, &ZodAttributes::default());
413 let mut chain = format!("z.array({})", inner_zod);
414
415 if let Some(len) = attrs.length {
416 chain.push_str(&format!(".length({})", len));
417 }
418 if let Some(min) = attrs.min_length {
419 chain.push_str(&format!(".min({})", min));
420 }
421 if let Some(max) = attrs.max_length {
422 chain.push_str(&format!(".max({})", max));
423 }
424
425 chain
426 } else {
427 format!("{}Schema", other)
428 }
429 }
430 };
431
432 base
433}
434
435fn append_number_validators(chain: &mut String, attrs: &ZodAttributes) {
436 if let Some(min) = attrs.min {
437 chain.push_str(&format!(".min({})", min));
438 }
439 if let Some(max) = attrs.max {
440 chain.push_str(&format!(".max({})", max));
441 }
442 if attrs.positive {
443 chain.push_str(".positive()");
444 }
445 if attrs.negative {
446 chain.push_str(".negative()");
447 }
448 if attrs.nonnegative {
449 chain.push_str(".nonnegative()");
450 }
451 if attrs.nonpositive {
452 chain.push_str(".nonpositive()");
453 }
454 if attrs.finite {
455 chain.push_str(".finite()");
456 }
457}
458
459fn type_to_string(ty: &syn::Type) -> String {
460 if let syn::Type::Path(type_path) = ty {
461 let segments: Vec<String> = type_path
462 .path
463 .segments
464 .iter()
465 .map(|seg| {
466 let ident = seg.ident.to_string();
467 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
468 let args_str: Vec<String> = args
469 .args
470 .iter()
471 .filter_map(|arg| {
472 if let syn::GenericArgument::Type(t) = arg {
473 Some(type_to_string(t))
474 } else {
475 None
476 }
477 })
478 .collect();
479 if args_str.is_empty() {
480 ident
481 } else {
482 format!("{}<{}>", ident, args_str.join(", "))
483 }
484 } else {
485 ident
486 }
487 })
488 .collect();
489 segments.join("::")
490 } else {
491 "unknown".to_string()
492 }
493}
494
495fn is_option_type(ty: &syn::Type) -> bool {
496 if let syn::Type::Path(type_path) = ty {
497 if let Some(segment) = type_path.path.segments.last() {
498 return segment.ident == "Option";
499 }
500 }
501 false
502}
503
504fn get_option_inner_type_str(ty: &syn::Type) -> String {
505 if let syn::Type::Path(type_path) = ty {
506 if let Some(segment) = type_path.path.segments.last() {
507 if segment.ident == "Option" {
508 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
509 if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
510 return type_to_string(inner_type);
511 }
512 }
513 }
514 }
515 }
516 "unknown".to_string()
517}