Skip to main content

ts_function/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    Error, FnArg, GenericArgument, Ident, Item, ItemImpl, ItemType, PathArguments, ReturnType,
5    Type, parse_macro_input,
6};
7
8#[macro_use]
9mod ts_type;
10mod ts_macro;
11
12use crate::ts_type::ToTsType;
13
14/// Generates TypeScript type aliases and `wasm-bindgen` ABI trait implementations
15/// for Rust items.
16///
17/// This attribute can be applied to:
18/// 1. **Structs**: Generates a TypeScript interface and property bindings.
19/// 2. **Enums**: Automatically applies `#[wasm_bindgen]` (supports C-style enums).
20/// 3. **Type Aliases**: Generates a TypeScript type alias and a typed function wrapper.
21/// 4. **Impl Blocks**: Generates a typed function wrapper for a `call` method.
22///
23/// # Examples
24///
25/// **Struct Usage**
26///
27/// ```rust,ignore
28/// #[ts(rename_all = "camelCase")]
29/// struct MyStruct {
30///     field_name: String,
31/// }
32/// ```
33///
34/// **Enum Usage**
35///
36/// ```rust,ignore
37/// #[ts]
38/// enum Status { Active, Inactive }
39/// ```
40///
41/// **Function Wrapper Usage**
42///
43/// ```rust,ignore
44/// #[ts]
45/// pub type OnReady = fn(msg: String);
46///
47/// #[ts]
48/// struct AppFunctions {
49///     on_ready: OnReady,
50/// }
51/// ```
52#[proc_macro_attribute]
53pub fn ts(attr: TokenStream, input: TokenStream) -> TokenStream {
54    let item = parse_macro_input!(input as Item);
55    ts_internal_dispatcher(attr.into(), item).into()
56}
57
58fn ts_internal_dispatcher(attr: proc_macro2::TokenStream, item: Item) -> proc_macro2::TokenStream {
59    let attr_args = attr.clone();
60
61    match &item {
62        Item::Struct(item_struct) => {
63            let args = match syn::parse2::<ts_macro::TsArgs>(attr_args) {
64                Ok(args) => args,
65                Err(err) => return err.to_compile_error(),
66            };
67            ts_macro::ts_internal(args, item_struct.clone())
68        }
69        Item::Enum(item_enum) => {
70            quote! {
71                #[::wasm_bindgen::prelude::wasm_bindgen]
72                #item_enum
73            }
74        }
75        Item::Type(item_type) => match parse_item_type(item_type) {
76            Ok(tokens) => tokens,
77            Err(err) => err.to_compile_error(),
78        },
79        Item::Impl(item_impl) => match parse_item_impl(item_impl) {
80            Ok(tokens) => tokens,
81            Err(err) => err.to_compile_error(),
82        },
83        _ => Error::new_spanned(
84            item,
85            "#[ts] can only be applied to a struct, enum, type alias, or impl block",
86        )
87        .to_compile_error(),
88    }
89}
90
91struct ParsedSignature<'a> {
92    struct_ident: &'a Ident,
93    args: Vec<(Ident, &'a Type)>,
94    output: &'a ReturnType,
95}
96
97fn generate_return_conversion(ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
98    match ty {
99        Type::Path(type_path) => {
100            let segment = type_path
101                .path
102                .segments
103                .last()
104                .ok_or_else(|| Error::new_spanned(ty, "Expected a type segment"))?;
105            let ident = &segment.ident;
106            let ident_str = ident.to_string();
107
108            if let Some(inner_ty) = get_slice_element_type(ty)
109                && let Some(arr_type) = get_typed_array_ident(inner_ty)
110            {
111                return Ok(quote! {
112                    let arr: ::js_sys::#arr_type = ::wasm_bindgen::JsCast::unchecked_into(res);
113                    Ok(::std::convert::Into::<#ty>::into(arr.to_vec()))
114                });
115            }
116
117            match ident_str.as_str() {
118                "f32" | "f64" | "i8" | "i16" | "i32" | "u8" | "u16" | "u32" => Ok(quote! {
119                    res.as_f64().map(|v| v as #ty).ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a number"))
120                }),
121                "i64" | "u64" => Ok(quote! {
122                    ::std::convert::TryInto::<#ty>::try_into(res).map_err(|_| ::wasm_bindgen::JsValue::from_str("Expected a BigInt"))
123                }),
124                "bool" => Ok(quote! {
125                    res.as_bool().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a boolean"))
126                }),
127                "String" => Ok(quote! {
128                    res.as_string().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a string"))
129                }),
130                "JsValue" => Ok(quote! {
131                    Ok(res)
132                }),
133                "Option" => {
134                    let PathArguments::AngleBracketed(args) = &segment.arguments else {
135                        return Err(Error::new_spanned(
136                            ty,
137                            "Expected generic argument for Option",
138                        ));
139                    };
140                    let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() else {
141                        return Err(Error::new_spanned(ty, "Expected type argument for Option"));
142                    };
143                    let inner_conversion = generate_return_conversion(inner_ty)?;
144                    Ok(quote! {
145                        if res.is_null() || res.is_undefined() {
146                            Ok(None)
147                        } else {
148                            let res = { #inner_conversion };
149                            res.map(Some)
150                        }
151                    })
152                }
153                _ => Ok(quote! {
154                    Ok(::wasm_bindgen::JsCast::unchecked_into::<#ty>(res))
155                }),
156            }
157        }
158        _ => Err(Error::new_spanned(
159            ty,
160            "Unsupported return type in type alias pattern. Use the `impl` escape hatch instead.",
161        )),
162    }
163}
164
165fn parse_item_type(item_type: &ItemType) -> syn::Result<proc_macro2::TokenStream> {
166    let Type::BareFn(bare_fn) = &*item_type.ty else {
167        return Err(Error::new_spanned(
168            &item_type.ty,
169            "Expected a function pointer type (e.g., `fn(x: f64)`)",
170        ));
171    };
172
173    let struct_ident = &item_type.ident;
174    let mut args = Vec::new();
175
176    for (i, arg) in bare_fn.inputs.iter().enumerate() {
177        let ident = match &arg.name {
178            Some((ident, _)) => ident.clone(),
179            None => format_ident!("arg{}", i),
180        };
181        args.push((ident, &arg.ty));
182    }
183
184    let parsed = ParsedSignature {
185        struct_ident,
186        args: args.clone(),
187        output: &bare_fn.output,
188    };
189
190    let abi_traits = generate_abi_traits(&parsed)?;
191
192    let mut fn_args = Vec::new();
193    let mut arg_conversions = Vec::new();
194    let mut call_args = Vec::new();
195    for (ident, ty) in &args {
196        fn_args.push(quote! { #ident: #ty });
197        let conversion = generate_conversion(ident, ty)?;
198        arg_conversions.push(conversion);
199        call_args.push(quote! { &#ident });
200    }
201
202    let args_len = call_args.len();
203    if args_len > 9 {
204        return Err(Error::new_spanned(
205            item_type,
206            "Functions with more than 9 arguments are not supported yet",
207        ));
208    }
209    let call_method_name = format_ident!("call{}", args_len);
210    let call_method = quote! { #call_method_name(&::wasm_bindgen::JsValue::NULL, #(#call_args),*) };
211
212    let output = parsed.output;
213    let (ret_type, ret_stmt) = match output {
214        ReturnType::Default => (quote! { () }, quote! { self.0.#call_method.map(|_| ()) }),
215        ReturnType::Type(_, ty) => {
216            let conversion = generate_return_conversion(ty)?;
217            (
218                quote! { #ty },
219                quote! {
220                    let res = self.0.#call_method?;
221                    #conversion
222                },
223            )
224        }
225    };
226
227    Ok(quote! {
228        pub struct #struct_ident(pub ::js_sys::Function);
229
230        impl #struct_ident {
231            pub fn call(&self, #(#fn_args),*) -> Result<#ret_type, ::wasm_bindgen::JsValue> {
232                #(#arg_conversions)*
233                #ret_stmt
234            }
235        }
236
237        #abi_traits
238    })
239}
240
241fn generate_conversion(ident: &Ident, ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
242    if let Type::ImplTrait(type_impl) = ty {
243        for bound in &type_impl.bounds {
244            if let syn::TypeParamBound::Trait(trait_bound) = bound
245                && let Some(segment) = trait_bound.path.segments.last()
246                && let PathArguments::AngleBracketed(args) = &segment.arguments
247                && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
248            {
249                match segment.ident.to_string().as_str() {
250                    "Into" => {
251                        let inner_conversion = generate_conversion(ident, inner_ty)?;
252                        return Ok(quote! {
253                            let #ident = ::std::convert::Into::<#inner_ty>::into(#ident);
254                            #inner_conversion
255                        });
256                    }
257                    "AsRef" => {
258                        if let Type::Slice(slice) = inner_ty {
259                            return Ok(generate_typed_array_conversion(ident, &slice.elem));
260                        }
261                    }
262                    _ => {}
263                }
264            }
265        }
266        return Err(Error::new_spanned(
267            ty,
268            "Unsupported `impl Trait`. Only `impl Into<T>` and `impl AsRef<[T]>` are supported.",
269        ));
270    }
271
272    if let Some(inner_ty) = get_slice_element_type(ty) {
273        Ok(generate_typed_array_conversion(ident, inner_ty))
274    } else {
275        Ok(quote! {
276            let #ident = ::std::convert::Into::<::wasm_bindgen::JsValue>::into(#ident);
277        })
278    }
279}
280
281fn generate_typed_array_conversion(ident: &Ident, inner_ty: &Type) -> proc_macro2::TokenStream {
282    if let Some(arr_type) = get_typed_array_ident(inner_ty) {
283        quote! {
284            let #ident = ::wasm_bindgen::JsValue::from(::js_sys::#arr_type::from(::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)));
285        }
286    } else {
287        quote! {
288            let #ident = ::wasm_bindgen::JsValue::from(
289                ::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)
290                    .iter()
291                    .map(::wasm_bindgen::JsValue::from)
292                    .collect::<::js_sys::Array>()
293            );
294        }
295    }
296}
297
298fn get_typed_array_ident(inner_ty: &Type) -> Option<proc_macro2::TokenStream> {
299    let inner_str = match inner_ty {
300        Type::Path(p) => p.path.segments.last().map(|s| s.ident.to_string()),
301        _ => None,
302    };
303
304    match inner_str.as_deref() {
305        Some("u8") => Some(quote! { Uint8Array }),
306        Some("i8") => Some(quote! { Int8Array }),
307        Some("u16") => Some(quote! { Uint16Array }),
308        Some("i16") => Some(quote! { Int16Array }),
309        Some("u32") => Some(quote! { Uint32Array }),
310        Some("i32") => Some(quote! { Int32Array }),
311        Some("f32") => Some(quote! { Float32Array }),
312        Some("f64") => Some(quote! { Float64Array }),
313        Some("u64") => Some(quote! { BigUint64Array }),
314        Some("i64") => Some(quote! { BigInt64Array }),
315        _ => None,
316    }
317}
318
319fn get_slice_element_type(ty: &Type) -> Option<&Type> {
320    match ty {
321        Type::Path(type_path) => {
322            let segment = type_path.path.segments.last()?;
323            // Types that implement AsRef<[T]> and we can easily extract T from AST
324            if matches!(
325                segment.ident.to_string().as_str(),
326                "Vec" | "Box" | "Arc" | "Rc"
327            ) && let PathArguments::AngleBracketed(args) = &segment.arguments
328                && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
329            {
330                if let Type::Slice(slice) = inner {
331                    return Some(&*slice.elem);
332                }
333                return Some(inner);
334            }
335        }
336        Type::Reference(type_ref) => {
337            if let Type::Slice(type_slice) = &*type_ref.elem {
338                return Some(&*type_slice.elem);
339            }
340            return get_slice_element_type(&type_ref.elem);
341        }
342        _ => {}
343    }
344    None
345}
346
347fn parse_item_impl(item_impl: &ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
348    if item_impl.trait_.is_some() {
349        return Err(Error::new_spanned(
350            item_impl,
351            "#[ts_function] cannot be applied to trait impls",
352        ));
353    }
354
355    let Type::Path(type_path) = &*item_impl.self_ty else {
356        return Err(Error::new_spanned(
357            &item_impl.self_ty,
358            "Expected a simple path for the struct",
359        ));
360    };
361
362    let struct_ident = type_path.path.get_ident().ok_or_else(|| {
363        Error::new_spanned(
364            &type_path.path,
365            "Expected a single identifier for the struct",
366        )
367    })?;
368
369    let method = item_impl
370        .items
371        .iter()
372        .find_map(|item| {
373            if let syn::ImplItem::Fn(method) = item
374                && method.sig.ident == "call"
375            {
376                return Some(method);
377            }
378            None
379        })
380        .ok_or_else(|| Error::new_spanned(item_impl, "Missing `call` method in impl block"))?;
381
382    let mut args = Vec::new();
383    let mut inputs_iter = method.sig.inputs.iter();
384
385    // Check first argument is `&self` or `&mut self`
386    match inputs_iter.next() {
387        Some(FnArg::Receiver(_)) => {}
388        _ => {
389            return Err(Error::new_spanned(
390                &method.sig,
391                "The `call` method must take `&self` or `&mut self` as its first parameter",
392            ));
393        }
394    }
395
396    for (i, arg) in inputs_iter.enumerate() {
397        let FnArg::Typed(pat_type) = arg else {
398            return Err(Error::new_spanned(arg, "Expected a typed argument"));
399        };
400
401        let ident = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
402            pat_ident.ident.clone()
403        } else {
404            format_ident!("arg{}", i)
405        };
406
407        args.push((ident, &*pat_type.ty));
408    }
409
410    let parsed = ParsedSignature {
411        struct_ident,
412        args,
413        output: &method.sig.output,
414    };
415
416    let abi_traits = generate_abi_traits(&parsed)?;
417
418    Ok(quote! {
419        #item_impl
420        #abi_traits
421    })
422}
423
424fn generate_abi_traits(parsed: &ParsedSignature) -> syn::Result<proc_macro2::TokenStream> {
425    let struct_ident = parsed.struct_ident;
426    let mut ts_args = Vec::new();
427
428    for (ident, ty) in &parsed.args {
429        let ts_ty = ty
430            .to_ts_type()
431            .map_err(|e| Error::new_spanned(ty, e.message))?
432            .to_string();
433        ts_args.push(format!("{}: {}", ident, ts_ty));
434    }
435
436    let ts_output = match parsed.output {
437        ReturnType::Default => "void".to_string(),
438        ReturnType::Type(_, ty) => ty
439            .to_ts_type()
440            .map_err(|e| Error::new_spanned(ty, e.message))?
441            .to_string(),
442    };
443
444    let ts_string = format!(
445        "type {} = ({}) => {};",
446        struct_ident,
447        ts_args.join(", "),
448        ts_output
449    );
450
451    let generated = quote! {
452        #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
453        const _: &'static str = #ts_string;
454
455        impl ::wasm_bindgen::describe::WasmDescribe for #struct_ident {
456            fn describe() {
457                <::js_sys::Function as ::wasm_bindgen::describe::WasmDescribe>::describe()
458            }
459        }
460
461        impl ::wasm_bindgen::convert::FromWasmAbi for #struct_ident {
462            type Abi = <::js_sys::Function as ::wasm_bindgen::convert::FromWasmAbi>::Abi;
463
464            unsafe fn from_abi(js: Self::Abi) -> Self {
465                Self(::js_sys::Function::from_abi(js))
466            }
467        }
468
469        impl ::wasm_bindgen::convert::OptionFromWasmAbi for #struct_ident {
470            fn is_none(abi: &Self::Abi) -> bool {
471                <::js_sys::Function as ::wasm_bindgen::convert::OptionFromWasmAbi>::is_none(abi)
472            }
473        }
474
475        impl From<::js_sys::Function> for #struct_ident {
476            fn from(f: ::js_sys::Function) -> Self {
477                Self(f)
478            }
479        }
480
481        impl From<#struct_ident> for ::wasm_bindgen::JsValue {
482            fn from(f: #struct_ident) -> Self {
483                ::wasm_bindgen::JsValue::from(f.0)
484            }
485        }
486    };
487
488    Ok(generated)
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Expands `item` through the dispatcher with no attribute arguments.
    fn expand(item: Item) -> String {
        ts_internal_dispatcher(quote! {}, item).to_string()
    }

    /// Asserts that every snippet in `needles` occurs in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected to find `{needle}` in generated output:\n{haystack}"
            );
        }
    }

    #[test]
    fn test_item_type() {
        let alias: ItemType = parse_quote! {
            pub type OnClick = fn(x: f64, y: impl Into<f64>, arr: js_sys::Float64Array);
        };
        let generated = parse_item_type(&alias).unwrap().to_string();

        assert_contains_all(
            &generated,
            &[
                "type OnClick = (x: number, y: number, arr: Float64Array) => void;",
                "pub struct OnClick (pub :: js_sys :: Function) ;",
                "pub fn call (& self , x : f64 , y : impl Into < f64 > , arr : js_sys :: Float64Array)",
            ],
        );
    }

    #[test]
    fn test_item_impl() {
        let block: ItemImpl = parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {
                    // body
                }
            }
        };
        let generated = parse_item_impl(&block).unwrap().to_string();

        assert_contains_all(
            &generated,
            &[
                "type OnScroll = (y: number) => void;",
                "impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll",
            ],
        );
    }

    #[test]
    fn test_dispatcher_item_struct() {
        let generated = expand(parse_quote! {
            pub struct MyStruct {
                pub field: f64,
            }
        });

        assert_contains_all(
            &generated,
            &["export interface MyStruct", "field: number;"],
        );
    }

    #[test]
    fn test_dispatcher_item_type() {
        let generated = expand(parse_quote! {
            pub type OnClick = fn(x: f64);
        });

        assert_contains_all(
            &generated,
            &[
                "type OnClick = (x: number) => void;",
                "pub struct OnClick (pub :: js_sys :: Function) ;",
            ],
        );
    }

    #[test]
    fn test_dispatcher_item_impl() {
        let generated = expand(parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {}
            }
        });

        assert_contains_all(
            &generated,
            &[
                "type OnScroll = (y: number) => void;",
                "impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll",
            ],
        );
    }

    #[test]
    fn test_enum_item() {
        let generated = expand(parse_quote! {
            pub enum Status { Active, Inactive }
        });

        assert_contains_all(
            &generated,
            &[
                "# [:: wasm_bindgen :: prelude :: wasm_bindgen]",
                "pub enum Status { Active , Inactive }",
            ],
        );
    }

    #[test]
    fn test_recursive_generics() {
        let cases: [(ItemType, &str); 2] = [
            (
                parse_quote! { pub type ResultFn = fn(res: Result<String, i32>); },
                "type ResultFn = (res: Result<string, number>) => void;",
            ),
            (
                parse_quote! { pub type NestedVecFn = fn(args: Vec<Vec<f64>>); },
                "type NestedVecFn = (args: Float64Array[]) => void;",
            ),
        ];

        for (alias, expected) in &cases {
            let generated = parse_item_type(alias).unwrap().to_string();
            assert_contains_all(&generated, &[expected]);
        }
    }
}