Skip to main content

typescript_macros/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("readme.md")]
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{Attribute, DeriveInput, ItemFn, Lit, Type, parse_macro_input};
7
8/// 字段属性信息
9struct FieldAttributes {
10    /// 字段重命名
11    rename: Option<String>,
12    /// 是否跳过该字段
13    skip: bool,
14    /// 是否标记为可选
15    optional: bool,
16    /// 自定义类型
17    custom_type: Option<String>,
18}
19
20impl FieldAttributes {
21    /// 从字段属性中解析字段属性信息
22    fn from_attributes(attrs: &[Attribute]) -> Self {
23        let mut result = Self { rename: None, skip: false, optional: false, custom_type: None };
24
25        for attr in attrs {
26            if attr.path().is_ident("ts") {
27                attr.parse_nested_meta(|meta| {
28                    if meta.path.is_ident("rename") {
29                        let value = meta.value()?;
30                        let lit_str = value.parse::<Lit>()?;
31                        if let Lit::Str(lit_str) = lit_str {
32                            result.rename = Some(lit_str.value());
33                        }
34                    }
35                    else if meta.path.is_ident("type") {
36                        let value = meta.value()?;
37                        let lit_str = value.parse::<Lit>()?;
38                        if let Lit::Str(lit_str) = lit_str {
39                            result.custom_type = Some(lit_str.value());
40                        }
41                    }
42                    else if meta.path.is_ident("skip") {
43                        result.skip = true;
44                    }
45                    else if meta.path.is_ident("optional") {
46                        result.optional = true;
47                    }
48                    Ok(())
49                })
50                .ok();
51            }
52        }
53
54        result
55    }
56}
57
58use once_cell::sync::Lazy;
// `HashMap` backs the `BASIC_TYPE_MAP` lookup table below; the full list of
// Rust -> TypeScript type mappings is documented on `rust_type_to_typescript`.
74use std::collections::HashMap;
75
/// Rust type names that map directly onto a TypeScript primitive.
///
/// Built lazily on first access. Keys are bare type names as written in source
/// (`u32`, `String`, `str`, ...); values are the corresponding TypeScript type names.
static BASIC_TYPE_MAP: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
    const NUMERIC_TYPES: [&str; 12] = [
        "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64", "usize", "isize",
    ];

    NUMERIC_TYPES
        .into_iter()
        .map(|name| (name, "number"))
        .chain([("bool", "boolean"), ("String", "string"), ("str", "string")])
        .collect()
});
99
100/// 将 Rust 类型转换为 TypeScript 类型字符串
101///
102/// 支持的类型映射:
103/// - 基本类型:u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, usize, isize -> number
104/// - 布尔类型:bool -> boolean
105/// - 字符串:String, &str -> string
106/// - 可选类型:Option<T> -> T | undefined
107/// - 集合类型:Vec<T>, LinkedList<T>, VecDeque<T> -> T[]
108/// - 集合类型:HashSet<T>, BTreeSet<T> -> Set<T>
109/// - 映射类型:HashMap<K, V>, BTreeMap<K, V> -> Record<K, V>
110/// - 元组类型:(A, B, C) -> [A, B, C]
111/// - 泛型类型:支持泛型参数的保留
112/// - 嵌套类型:递归处理内部类型
113/// - 空类型:() -> null
114/// - 切片类型:&[T] -> T[]
115fn rust_type_to_typescript(ty: &Type) -> String {
116    match ty {
117        Type::Path(type_path) => {
118            let path = &type_path.path;
119
120            if let Some(segment) = path.segments.last() {
121                let ident = &segment.ident;
122                let ident_str = ident.to_string();
123
124                // 处理基本类型(使用缓存)
125                if let Some(ts_type) = BASIC_TYPE_MAP.get(ident_str.as_str()) {
126                    return ts_type.to_string();
127                }
128
129                // 处理容器类型
130                match ident_str.as_str() {
131                    "Option" => {
132                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
133                            && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
134                        {
135                            let inner_ts = rust_type_to_typescript(inner_ty);
136                            return format!("({}) | undefined", inner_ts);
137                        }
138                        return "any | undefined".to_string();
139                    }
140                    "Vec" | "LinkedList" | "VecDeque" => {
141                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
142                            && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
143                        {
144                            let inner_ts = rust_type_to_typescript(inner_ty);
145                            return format!("({})[]", inner_ts);
146                        }
147                        return "any[]".to_string();
148                    }
149                    "HashSet" | "BTreeSet" => {
150                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
151                            && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
152                        {
153                            let inner_ts = rust_type_to_typescript(inner_ty);
154                            return format!("Set<{}>", inner_ts);
155                        }
156                        return "Set<any>".to_string();
157                    }
158                    "HashMap" | "BTreeMap" => {
159                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
160                            let args_vec: Vec<_> = args.args.iter().collect();
161                            if args_vec.len() == 2 {
162                                let key_ts = if let syn::GenericArgument::Type(key_ty) = args_vec[0] {
163                                    rust_type_to_typescript(key_ty)
164                                }
165                                else {
166                                    "any".to_string()
167                                };
168                                let value_ts = if let syn::GenericArgument::Type(value_ty) = args_vec[1] {
169                                    rust_type_to_typescript(value_ty)
170                                }
171                                else {
172                                    "any".to_string()
173                                };
174                                return format!("Record<{}, {}>", key_ts, value_ts);
175                            }
176                        }
177                        return "Record<string, any>".to_string();
178                    }
179                    _ => {
180                        // 处理泛型类型,如 Container<T>
181                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
182                            let generic_args: Vec<String> = args
183                                .args
184                                .iter()
185                                .filter_map(|arg| {
186                                    if let syn::GenericArgument::Type(inner_ty) = arg {
187                                        Some(rust_type_to_typescript(inner_ty))
188                                    }
189                                    else if let syn::GenericArgument::Lifetime(_) = arg {
190                                        None // 忽略生命周期参数
191                                    }
192                                    else {
193                                        Some(quote!(#arg).to_string())
194                                    }
195                                })
196                                .collect();
197
198                            if !generic_args.is_empty() {
199                                return format!("{}<{}>", ident, generic_args.join(", "));
200                            }
201                        }
202                        // 保留其他类型的名称,可能是用户定义的类型或泛型参数
203                        ident_str
204                    }
205                }
206            }
207            else {
208                quote!(#ty).to_string()
209            }
210        }
211        Type::Tuple(type_tuple) => {
212            if type_tuple.elems.is_empty() {
213                return "null".to_string();
214            }
215            let elements: Vec<String> = type_tuple.elems.iter().map(rust_type_to_typescript).collect();
216            format!("[{}]", elements.join(", "))
217        }
218        Type::Slice(slice) => {
219            let inner_ts = rust_type_to_typescript(&slice.elem);
220            format!("({})[]", inner_ts)
221        }
222        Type::Reference(ref_type) => {
223            let inner_ty = &ref_type.elem;
224            rust_type_to_typescript(inner_ty)
225        }
226        _ => {
227            let type_str = quote!(#ty).to_string();
228            // 处理基本类型(使用缓存)
229            if let Some(ts_type) = BASIC_TYPE_MAP.get(type_str.as_str()) {
230                return ts_type.to_string();
231            }
232            // 处理空类型
233            if type_str == "()" {
234                return "null".to_string();
235            }
236            type_str
237        }
238    }
239}
240
241/// 为 Rust 结构体生成 TypeScript 类定义
242///
243/// # 示例
244///
245/// ```rust
246/// use typescript_macros::TypescriptClass;
247///
248/// #[derive(TypescriptClass)]
249/// struct User {
250///     id: u32,
251///     name: String,
252///     active: bool,
253/// }
254/// ```
255///
256/// 这将生成对应的 TypeScript 类定义:
257///
258/// ```typescript
259/// class User {
260///     id: number;
261///     name: string;
262///     active: boolean;
263///     
264///     constructor(id: number, name: string, active: boolean) {
265///         this.id = id;
266///         this.name = name;
267///         this.active = active;
268///     }
269/// }
270/// ```
271#[proc_macro_derive(TypescriptClass, attributes(ts))]
272pub fn typescript_class_derive(input: TokenStream) -> TokenStream {
273    let input = parse_macro_input!(input as DeriveInput);
274
275    let struct_name = &input.ident;
276
277    // 解析泛型参数
278    let generic_params = if let Some(_generics) = input.generics.params.first() {
279        let generic_args: Vec<String> =
280            input
281                .generics
282                .params
283                .iter()
284                .filter_map(|param| {
285                    if let syn::GenericParam::Type(type_param) = param { Some(type_param.ident.to_string()) } else { None }
286                })
287                .collect();
288
289        if !generic_args.is_empty() { format!("<{}>", generic_args.join(", ")) } else { "".to_string() }
290    }
291    else {
292        "".to_string()
293    };
294
295    let fields = match &input.data {
296        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(fields), .. }) => &fields.named,
297        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unnamed(_), .. }) => {
298            return syn::Error::new_spanned(input, "Tuple structs are not supported").to_compile_error().into();
299        }
300        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unit, .. }) => {
301            return syn::Error::new_spanned(input, "Unit structs are not supported").to_compile_error().into();
302        }
303        syn::Data::Enum(_) => {
304            return syn::Error::new_spanned(input, "Enums are not supported").to_compile_error().into();
305        }
306        syn::Data::Union(_) => {
307            return syn::Error::new_spanned(input, "Unions are not supported").to_compile_error().into();
308        }
309    };
310
311    for field in fields {
312        if field.ident.is_none() {
313            return syn::Error::new_spanned(field, "All fields must have names").to_compile_error().into();
314        }
315    }
316
317    let field_types: Vec<String> = fields
318        .iter()
319        .filter_map(|field| {
320            let attrs = FieldAttributes::from_attributes(&field.attrs);
321            if attrs.skip {
322                return None;
323            }
324            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
325            let mut ts_type = if let Some(custom_type) = &attrs.custom_type {
326                custom_type.clone()
327            }
328            else {
329                rust_type_to_typescript(&field.ty)
330            };
331            if attrs.optional && !ts_type.contains("undefined") {
332                ts_type = format!("{} | undefined", ts_type);
333            }
334            Some(format!("    {}: {}", field_name, ts_type))
335        })
336        .collect();
337
338    let constructor_params: Vec<String> = fields
339        .iter()
340        .filter_map(|field| {
341            let attrs = FieldAttributes::from_attributes(&field.attrs);
342            if attrs.skip {
343                return None;
344            }
345            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
346            let mut ts_type = if let Some(custom_type) = &attrs.custom_type {
347                custom_type.clone()
348            }
349            else {
350                rust_type_to_typescript(&field.ty)
351            };
352            if attrs.optional && !ts_type.contains("undefined") {
353                ts_type = format!("{} | undefined", ts_type);
354            }
355            Some(format!("        {}: {}", field_name, ts_type))
356        })
357        .collect();
358
359    let constructor_assignments: Vec<String> = fields
360        .iter()
361        .filter_map(|field| {
362            let attrs = FieldAttributes::from_attributes(&field.attrs);
363            if attrs.skip {
364                return None;
365            }
366            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
367            let original_field_name = field.ident.as_ref().unwrap();
368            Some(format!("        this.{} = {};", field_name, original_field_name))
369        })
370        .collect();
371
372    let field_types_str = field_types.join(";\n");
373    let constructor_params_str = constructor_params.join(",\n");
374    let constructor_assignments_str = constructor_assignments.join("\n");
375
376    let ts_code = quote! {
377        impl #struct_name {
378            /// TypeScript 类定义
379            pub const TS_CLASS_DEFINITION: &'static str = concat!(
380                "class ", stringify!(#struct_name), #generic_params, " {\n",
381                #field_types_str, ";\n",
382                "\n",
383                "    constructor(\n",
384                #constructor_params_str, ",\n",
385                "    ) {\n",
386                #constructor_assignments_str, "\n",
387                "    }\n",
388                "}\n"
389            );
390
391            /// 获取 TypeScript 类定义
392            pub fn ts_class_definition() -> &'static str {
393                Self::TS_CLASS_DEFINITION
394            }
395        }
396    };
397
398    TokenStream::from(ts_code)
399}
400
401/// 为 Rust 函数生成 TypeScript 函数类型定义
402///
403/// # 示例
404///
405/// ```rust
406/// use typescript_macros::typescript_function;
407///
408/// #[typescript_function]
409/// fn add(a: u32, b: u32) -> u32 {
410///     a + b
411/// }
412/// ```
413///
414/// 这将生成对应的 TypeScript 函数类型定义:
415///
416/// ```typescript
417/// type AddFunction = (a: number, b: number) => number;
418/// ```
419#[proc_macro_attribute]
420pub fn typescript_function(_args: TokenStream, input: TokenStream) -> TokenStream {
421    let input = parse_macro_input!(input as ItemFn);
422
423    let fn_name = &input.sig.ident;
424
425    let params = &input.sig.inputs;
426    let param_types: Vec<String> = params
427        .iter()
428        .map(|param| match param {
429            syn::FnArg::Typed(pat_type) => {
430                let param_name = match &*pat_type.pat {
431                    syn::Pat::Ident(pat_ident) => &pat_ident.ident,
432                    _ => return "_: any".to_string(),
433                };
434                let ts_type = rust_type_to_typescript(&pat_type.ty);
435                format!("{}: {}", param_name, ts_type)
436            }
437            _ => "_: any".to_string(),
438        })
439        .collect();
440
441    let return_type = match &input.sig.output {
442        syn::ReturnType::Type(_, ty) => rust_type_to_typescript(ty),
443        syn::ReturnType::Default => "void".to_string(),
444    };
445
446    let param_types_str = param_types.join(", ");
447    let ts_const_name = format!("{}_TS_FUNCTION_DEFINITION", fn_name);
448    let ts_const_ident = syn::Ident::new(&ts_const_name, fn_name.span());
449    let ts_function_name = syn::Ident::new(&format!("{}_ts_function_definition", fn_name), fn_name.span());
450
451    let ts_code = quote! {
452        #input
453
454        /// TypeScript 函数类型定义
455        pub const #ts_const_ident: &'static str = concat!(
456            "type ", stringify!(#fn_name), "Function = (",
457            #param_types_str,
458            ") => ",
459            #return_type,
460            ";\n"
461        );
462
463        /// 获取 TypeScript 函数类型定义
464        pub fn #ts_function_name() -> &'static str {
465            #ts_const_ident
466        }
467    };
468
469    TokenStream::from(ts_code)
470}
471
472/// 为 Rust 函数生成 TypeScript 函数声明
473///
474/// # 示例
475///
476/// ```rust
477/// use typescript_macros::typescript_function_declaration;
478///
479/// #[typescript_function_declaration]
480/// fn add(a: u32, b: u32) -> u32 {
481///     a + b
482/// }
483/// ```
484///
485/// 这将生成对应的 TypeScript 函数声明:
486///
487/// ```typescript
488/// function add(a: number, b: number): number;
489/// ```
490#[proc_macro_attribute]
491pub fn typescript_function_declaration(_args: TokenStream, input: TokenStream) -> TokenStream {
492    let input = parse_macro_input!(input as ItemFn);
493
494    let fn_name = &input.sig.ident;
495
496    let params = &input.sig.inputs;
497    let param_types: Vec<String> = params
498        .iter()
499        .map(|param| match param {
500            syn::FnArg::Typed(pat_type) => {
501                let param_name = match &*pat_type.pat {
502                    syn::Pat::Ident(pat_ident) => &pat_ident.ident,
503                    _ => return "_: any".to_string(),
504                };
505                let ts_type = rust_type_to_typescript(&pat_type.ty);
506                format!("{}: {}", param_name, ts_type)
507            }
508            _ => "_: any".to_string(),
509        })
510        .collect();
511
512    let return_type = match &input.sig.output {
513        syn::ReturnType::Type(_, ty) => rust_type_to_typescript(ty),
514        syn::ReturnType::Default => "void".to_string(),
515    };
516
517    let param_types_str = param_types.join(", ");
518    let ts_const_name = format!("{}_TS_FUNCTION_DECLARATION", fn_name);
519    let ts_const_ident = syn::Ident::new(&ts_const_name, fn_name.span());
520    let ts_function_name = syn::Ident::new(&format!("{}_ts_function_declaration", fn_name), fn_name.span());
521
522    let ts_code = quote! {
523        #input
524
525        /// TypeScript 函数声明
526        pub const #ts_const_ident: &'static str = concat!(
527            "function ", stringify!(#fn_name), "(",
528            #param_types_str,
529            "): ",
530            #return_type,
531            ";\n"
532        );
533
534        /// 获取 TypeScript 函数声明
535        pub fn #ts_function_name() -> &'static str {
536            #ts_const_ident
537        }
538    };
539
540    TokenStream::from(ts_code)
541}
542
543/// 为 Rust 模块生成 TypeScript 命名空间
544///
545/// # 示例
546///
547/// ```rust
548/// use typescript_macros::typescript_namespace;
549///
550/// #[typescript_namespace("utils")]
551/// mod utils {
552///     pub fn add(a: u32, b: u32) -> u32 {
553///         a + b
554///     }
555///
556///     pub struct Point {
557///         pub x: f64,
558///         pub y: f64,
559///     }
560/// }
561/// ```
562///
563/// 这将生成对应的 TypeScript 命名空间:
564///
565/// ```typescript
566/// namespace utils {
567///     export function add(a: number, b: number): number;
568///     
569///     export interface Point {
570///         x: number;
571///         y: number;
572///     }
573/// }
574/// ```
575#[proc_macro_attribute]
576pub fn typescript_namespace(args: TokenStream, input: TokenStream) -> TokenStream {
577    let namespace_name = parse_macro_input!(args as syn::LitStr).value();
578    let input = parse_macro_input!(input as syn::ItemMod);
579
580    let mod_name = &input.ident;
581    let _content = &input.content;
582
583    // 这里简化处理,实际项目中可能需要更复杂的解析
584    // 目前我们只是生成命名空间的框架
585
586    let ts_const_name = format!("{}_TS_NAMESPACE", mod_name);
587    let ts_const_ident = syn::Ident::new(&ts_const_name, mod_name.span());
588    let ts_namespace_name = syn::Ident::new(&format!("{}_ts_namespace", mod_name), mod_name.span());
589
590    let ts_code = quote! {
591        #input
592
593        /// TypeScript 命名空间定义
594        pub const #ts_const_ident: &'static str = concat!(
595            "namespace ", #namespace_name, " {\n",
596            "    // 模块内容将在这里生成\n",
597            "}\n"
598        );
599
600        /// 获取 TypeScript 命名空间定义
601        pub fn #ts_namespace_name() -> &'static str {
602            #ts_const_ident
603        }
604    };
605
606    TokenStream::from(ts_code)
607}
608
609/// 为 Rust 类型生成 TypeScript 类型守卫
610///
611/// # 示例
612///
613/// ```rust
614/// use typescript_macros::TypescriptGuard;
615///
616/// #[derive(TypescriptGuard)]
617/// struct User {
618///     id: u32,
619///     name: String,
620///     active: bool,
621/// }
622/// ```
623///
624/// 这将生成对应的 TypeScript 类型守卫:
625///
626/// ```typescript
627/// function isUser(obj: any): obj is User {
628///     return (
629///         typeof obj === 'object' && obj !== null &&
630///         typeof obj.id === 'number' &&
631///         typeof obj.name === 'string' &&
632///         typeof obj.active === 'boolean'
633///     );
634/// }
635/// ```
636#[proc_macro_derive(TypescriptGuard, attributes(ts))]
637pub fn typescript_guard_derive(input: TokenStream) -> TokenStream {
638    let input = parse_macro_input!(input as DeriveInput);
639
640    let type_name = &input.ident;
641
642    if !matches!(input.data, syn::Data::Struct(_)) {
643        return syn::Error::new_spanned(input, "Only structs are supported for TypescriptGuard").to_compile_error().into();
644    }
645
646    let fields = match &input.data {
647        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(fields), .. }) => &fields.named,
648        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unnamed(_), .. }) => {
649            return syn::Error::new_spanned(input, "Tuple structs are not supported").to_compile_error().into();
650        }
651        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unit, .. }) => {
652            return syn::Error::new_spanned(input, "Unit structs are not supported").to_compile_error().into();
653        }
654        _ => unreachable!(),
655    };
656
657    let guard_conditions: Vec<String> = fields
658        .iter()
659        .filter_map(|field| {
660            let attrs = FieldAttributes::from_attributes(&field.attrs);
661            if attrs.skip {
662                return None;
663            }
664            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
665            let ts_type = rust_type_to_typescript(&field.ty);
666
667            // 根据 TypeScript 类型生成相应的类型检查
668            let condition = match ts_type.as_str() {
669                "number" => format!("typeof obj.{field_name} === 'number'"),
670                "boolean" => format!("typeof obj.{field_name} === 'boolean'"),
671                "string" => format!("typeof obj.{field_name} === 'string'"),
672                _ => format!("obj.{field_name} !== undefined"), // 对于复杂类型,只检查存在性
673            };
674
675            Some(condition)
676        })
677        .collect();
678
679    let guard_conditions_str = guard_conditions.join(" &&\n        ");
680    let guard_function_name = format!("is{}", type_name);
681    let guard_function_ident = syn::Ident::new(&guard_function_name, type_name.span());
682
683    let ts_const_name = format!("{}_TS_GUARD", type_name);
684    let ts_const_ident = syn::Ident::new(&ts_const_name, type_name.span());
685    let ts_guard_name = syn::Ident::new(&format!("{}_ts_guard", type_name), type_name.span());
686
687    let ts_code = quote! {
688        impl #type_name {
689            /// TypeScript 类型守卫函数
690            pub const #ts_const_ident: &'static str = concat!(
691                "function ", stringify!(#guard_function_ident), "(obj: any): obj is ", stringify!(#type_name), " {\n",
692                "    return (\n",
693                "        typeof obj === 'object' && obj !== null &&\n",
694                "        #guard_conditions_str\n",
695                "    );\n",
696                "}\n"
697            );
698
699            /// 获取 TypeScript 类型守卫函数
700            pub fn #ts_guard_name() -> &'static str {
701                Self::#ts_const_ident
702            }
703        }
704    };
705
706    TokenStream::from(ts_code)
707}
708
709/// 为 Rust 接口生成 TypeScript 接口定义
710///
711/// # 示例
712///
713/// ```rust
714/// use typescript_macros::TypescriptInterface;
715///
716/// #[derive(TypescriptInterface)]
717/// struct User {
718///     id: u32,
719///     name: String,
720///     active: bool,
721/// }
722/// ```
723///
724/// 这将生成对应的 TypeScript 接口定义:
725///
726/// ```typescript
727/// interface User {
728///     id: number;
729///     name: string;
730///     active: boolean;
731/// }
732/// ```
733#[proc_macro_derive(TypescriptInterface, attributes(ts))]
734pub fn typescript_interface_derive(input: TokenStream) -> TokenStream {
735    let input = parse_macro_input!(input as DeriveInput);
736
737    let trait_name = &input.ident;
738
739    // 解析泛型参数
740    let generic_params = if let Some(_generics) = input.generics.params.first() {
741        let generic_args: Vec<String> =
742            input
743                .generics
744                .params
745                .iter()
746                .filter_map(|param| {
747                    if let syn::GenericParam::Type(type_param) = param { Some(type_param.ident.to_string()) } else { None }
748                })
749                .collect();
750
751        if !generic_args.is_empty() { format!("<{}>", generic_args.join(", ")) } else { "".to_string() }
752    }
753    else {
754        "".to_string()
755    };
756
757    if !matches!(input.data, syn::Data::Struct(_)) {
758        return syn::Error::new_spanned(input, "Only structs are supported for TypescriptInterface").to_compile_error().into();
759    }
760
761    let fields = match &input.data {
762        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(fields), .. }) => &fields.named,
763        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unnamed(_), .. }) => {
764            return syn::Error::new_spanned(input, "Tuple structs are not supported").to_compile_error().into();
765        }
766        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unit, .. }) => {
767            return syn::Error::new_spanned(input, "Unit structs are not supported").to_compile_error().into();
768        }
769        _ => unreachable!(),
770    };
771
772    let interface_fields: Vec<String> = fields
773        .iter()
774        .filter_map(|field| {
775            let attrs = FieldAttributes::from_attributes(&field.attrs);
776            if attrs.skip {
777                return None;
778            }
779            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
780            let mut ts_type = if let Some(custom_type) = &attrs.custom_type {
781                custom_type.clone()
782            }
783            else {
784                rust_type_to_typescript(&field.ty)
785            };
786            if attrs.optional && !ts_type.contains("undefined") {
787                ts_type = format!("{} | undefined", ts_type);
788            }
789            Some(format!("    {}: {}", field_name, ts_type))
790        })
791        .collect();
792
793    let interface_fields_str = interface_fields.join(";\n");
794
795    let ts_code = quote! {
796        impl #trait_name {
797            /// TypeScript 接口定义
798            pub const TS_INTERFACE_DEFINITION: &'static str = concat!(
799                "interface ", stringify!(#trait_name), #generic_params, " {\n",
800                #interface_fields_str, ";\n",
801                "}\n"
802            );
803
804            /// 获取 TypeScript 接口定义
805            pub fn ts_interface_definition() -> &'static str {
806                Self::TS_INTERFACE_DEFINITION
807            }
808        }
809    };
810
811    TokenStream::from(ts_code)
812}
813
814/// 为 Rust 枚举生成 TypeScript 枚举定义
815///
816/// # 示例
817///
818/// ```rust
819/// use typescript_macros::TypescriptEnum;
820///
821/// #[derive(TypescriptEnum)]
822/// enum Color {
823///     Red,
824///     Green,
825///     Blue,
826/// }
827/// ```
828///
829/// 这将生成对应的 TypeScript 枚举定义:
830///
831/// ```typescript
832/// enum Color {
833///     Red = 0,
834///     Green = 1,
835///     Blue = 2,
836/// }
837/// ```
838///
839/// # 字符串枚举示例
840///
841/// ```rust
842/// use typescript_macros::TypescriptEnum;
843///
844/// #[derive(TypescriptEnum)]
845/// enum Direction {
846///     Up,
847///     Down,
848///     Left,
849///     Right,
850/// }
851/// ```
852///
853/// 这将生成对应的 TypeScript 枚举定义:
854///
855/// ```typescript
856/// enum Direction {
857///     Up = "UP",
858///     Down = "DOWN",
859///     Left = "LEFT",
860///     Right = "RIGHT",
861/// }
862/// ```
863#[proc_macro_derive(TypescriptEnum, attributes(ts))]
864pub fn typescript_enum_derive(input: TokenStream) -> TokenStream {
865    let input = parse_macro_input!(input as DeriveInput);
866
867    let enum_name = &input.ident;
868
869    let variants = match &input.data {
870        syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
871        _ => {
872            return syn::Error::new_spanned(input, "Only enums are supported").to_compile_error().into();
873        }
874    };
875
876    let enum_variants: Vec<String> = variants
877        .iter()
878        .enumerate()
879        .map(|(index, variant)| {
880            let variant_name = &variant.ident;
881
882            Ok(match &variant.fields {
883                syn::Fields::Unit => {
884                    format!("    {} = {}", variant_name, index)
885                }
886                syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
887                    let value = &fields.unnamed[0];
888                    let value_str = quote!(#value).to_string();
889                    format!("    {} = {}", variant_name, value_str)
890                }
891                _ => {
892                    return Err(syn::Error::new_spanned(
893                        variant,
894                        "Only unit variants or variants with a single value are supported",
895                    ));
896                }
897            })
898        })
899        .collect::<Result<Vec<_>, _>>()
900        .unwrap();
901
902    let enum_variants_str = enum_variants.join(",\n");
903
904    let ts_code = quote! {
905        impl #enum_name {
906            /// TypeScript 枚举定义
907            pub const TS_ENUM_DEFINITION: &'static str = concat!(
908                "enum ", stringify!(#enum_name), " {\n",
909                #enum_variants_str,
910                "\n}"
911            );
912
913            /// 获取 TypeScript 枚举定义
914            pub fn ts_enum_definition() -> &'static str {
915                Self::TS_ENUM_DEFINITION
916            }
917        }
918    };
919
920    TokenStream::from(ts_code)
921}
922
923/// 为 Rust 类型定义生成 TypeScript 类型别名
924///
925/// # 示例
926///
927/// ```rust
928/// use typescript_macros::TypescriptType;
929///
930/// #[derive(TypescriptType)]
931/// struct User {
932///     id: u32,
933///     name: String,
934///     active: bool,
935/// }
936/// ```
937///
938/// 这将生成对应的 TypeScript 类型别名:
939///
940/// ```typescript
941/// type User = {
942///     id: number;
943///     name: string;
944///     active: boolean;
945/// };
946/// ```
947#[proc_macro_derive(TypescriptType, attributes(ts))]
948pub fn typescript_type_derive(input: TokenStream) -> TokenStream {
949    let input = parse_macro_input!(input as DeriveInput);
950
951    let type_name = &input.ident;
952
953    // 解析泛型参数
954    let generic_params = if let Some(_generics) = input.generics.params.first() {
955        let generic_args: Vec<String> =
956            input
957                .generics
958                .params
959                .iter()
960                .filter_map(|param| {
961                    if let syn::GenericParam::Type(type_param) = param { Some(type_param.ident.to_string()) } else { None }
962                })
963                .collect();
964
965        if !generic_args.is_empty() { format!("<{}>", generic_args.join(", ")) } else { "".to_string() }
966    }
967    else {
968        "".to_string()
969    };
970
971    if !matches!(input.data, syn::Data::Struct(_)) {
972        return syn::Error::new_spanned(input, "Only structs are supported for TypescriptType").to_compile_error().into();
973    }
974
975    let fields = match &input.data {
976        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(fields), .. }) => &fields.named,
977        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unnamed(_), .. }) => {
978            return syn::Error::new_spanned(input, "Tuple structs are not supported").to_compile_error().into();
979        }
980        syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Unit, .. }) => {
981            return syn::Error::new_spanned(input, "Unit structs are not supported").to_compile_error().into();
982        }
983        _ => unreachable!(),
984    };
985
986    let type_fields: Vec<String> = fields
987        .iter()
988        .filter_map(|field| {
989            let attrs = FieldAttributes::from_attributes(&field.attrs);
990            if attrs.skip {
991                return None;
992            }
993            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
994            let mut ts_type = if let Some(custom_type) = &attrs.custom_type {
995                custom_type.clone()
996            }
997            else {
998                rust_type_to_typescript(&field.ty)
999            };
1000            if attrs.optional && !ts_type.contains("undefined") {
1001                ts_type = format!("{} | undefined", ts_type);
1002            }
1003            Some(format!("    {}: {}", field_name, ts_type))
1004        })
1005        .collect();
1006
1007    let type_fields_str = type_fields.join(";\n");
1008
1009    let ts_code = quote! {
1010        impl #type_name {
1011            /// TypeScript 类型别名定义
1012            pub const TS_TYPE_DEFINITION: &'static str = concat!(
1013                "type ", stringify!(#type_name), #generic_params, " = {\n",
1014                #type_fields_str, ";\n",
1015                "}\n"
1016            );
1017
1018            /// 获取 TypeScript 类型别名定义
1019            pub fn ts_type_definition() -> &'static str {
1020                Self::TS_TYPE_DEFINITION
1021            }
1022        }
1023    };
1024
1025    TokenStream::from(ts_code)
1026}
1027
1028/// 为 Rust 枚举生成 TypeScript 联合类型
1029///
1030/// # 示例
1031///
1032/// ```rust
1033/// use typescript_macros::TypescriptUnion;
1034///
1035/// #[derive(TypescriptUnion)]
1036/// enum Shape {
1037///     Circle { radius: f64 },
1038///     Rectangle { width: f64, height: f64 },
1039///     Triangle { base: f64, height: f64 },
1040/// }
1041/// ```
1042///
1043/// 这将生成对应的 TypeScript 联合类型:
1044///
1045/// ```typescript
1046/// type Shape =
1047///     | { type: "Circle"; radius: number }
1048///     | { type: "Rectangle"; width: number; height: number }
1049///     | { type: "Triangle"; base: number; height: number };
1050/// ```
1051#[proc_macro_derive(TypescriptUnion, attributes(ts))]
1052pub fn typescript_union_derive(input: TokenStream) -> TokenStream {
1053    let input = parse_macro_input!(input as DeriveInput);
1054
1055    let enum_name = &input.ident;
1056
1057    let variants = match &input.data {
1058        syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
1059        _ => {
1060            return syn::Error::new_spanned(input, "Only enums are supported for TypescriptUnion").to_compile_error().into();
1061        }
1062    };
1063
1064    let union_variants: Vec<String> = variants
1065        .iter()
1066        .map(|variant| {
1067            let variant_name = variant.ident.to_string();
1068
1069            Ok(match &variant.fields {
1070                syn::Fields::Unit => {
1071                    format!("    | {{ type: \"{}\" }}", variant_name)
1072                }
1073                syn::Fields::Named(fields) => {
1074                    let field_defs: Vec<String> = fields
1075                        .named
1076                        .iter()
1077                        .filter_map(|field| {
1078                            let attrs = FieldAttributes::from_attributes(&field.attrs);
1079                            if attrs.skip {
1080                                return None;
1081                            }
1082                            let field_name = attrs.rename.clone().unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
1083                            let mut ts_type = if let Some(custom_type) = &attrs.custom_type {
1084                                custom_type.clone()
1085                            }
1086                            else {
1087                                rust_type_to_typescript(&field.ty)
1088                            };
1089                            if attrs.optional && !ts_type.contains("undefined") {
1090                                ts_type = format!("{} | undefined", ts_type);
1091                            }
1092                            Some(format!("    {}: {}", field_name, ts_type))
1093                        })
1094                        .collect();
1095
1096                    let fields_str =
1097                        if field_defs.is_empty() { "".to_string() } else { format!(";\n{}", field_defs.join(";\n")) };
1098
1099                    format!("    | {{ type: \"{}\"{}}}", variant_name, fields_str)
1100                }
1101                syn::Fields::Unnamed(fields) => {
1102                    if fields.unnamed.is_empty() {
1103                        format!("    | \"{}\"", variant_name)
1104                    }
1105                    else if fields.unnamed.len() == 1 {
1106                        let field_ty = &fields.unnamed[0].ty;
1107                        let ts_type = rust_type_to_typescript(field_ty);
1108                        format!("    | {{ type: \"{}\"; value: {} }}", variant_name, ts_type)
1109                    }
1110                    else {
1111                        return Err(syn::Error::new_spanned(
1112                            variant,
1113                            "Only unit variants, named variants, or variants with a single value are supported",
1114                        ));
1115                    }
1116                }
1117            })
1118        })
1119        .collect::<Result<Vec<_>, _>>()
1120        .unwrap();
1121
1122    let union_variants_str = union_variants.join("\n");
1123
1124    let ts_code = quote! {
1125        impl #enum_name {
1126            /// TypeScript 联合类型定义
1127            pub const TS_UNION_DEFINITION: &'static str = concat!(
1128                "type ", stringify!(#enum_name), " =\n",
1129                #union_variants_str,
1130                ";\n"
1131            );
1132
1133            /// 获取 TypeScript 联合类型定义
1134            pub fn ts_union_definition() -> &'static str {
1135                Self::TS_UNION_DEFINITION
1136            }
1137        }
1138    };
1139
1140    TokenStream::from(ts_code)
1141}