Skip to main content

ts_function/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    Error, FnArg, GenericArgument, Ident, Item, ItemImpl, ItemType, PathArguments, ReturnType,
5    Type, parse_macro_input,
6};
7
8#[macro_use]
9mod ts_type;
10mod ts_macro;
11
12use crate::ts_type::ToTsType;
13
14#[proc_macro_attribute]
15pub fn ts(attr: TokenStream, input: TokenStream) -> TokenStream {
16    ts_macro::ts(attr, input)
17}
18
19struct ParsedSignature<'a> {
20    struct_ident: &'a Ident,
21    args: Vec<(Ident, &'a Type)>,
22    output: &'a ReturnType,
23}
24
25#[proc_macro_attribute]
26pub fn ts_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
27    let item = parse_macro_input!(item as Item);
28
29    let result = match &item {
30        Item::Type(item_type) => parse_item_type(item_type),
31        Item::Impl(item_impl) => parse_item_impl(item_impl),
32        _ => {
33            return Error::new_spanned(
34                item,
35                "#[ts_function] can only be applied to a type alias or an impl block",
36            )
37            .to_compile_error()
38            .into();
39        }
40    };
41
42    match result {
43        Ok(tokens) => tokens.into(),
44        Err(err) => err.to_compile_error().into(),
45    }
46}
47
48fn generate_return_conversion(ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
49    match ty {
50        Type::Path(type_path) => {
51            let segment = type_path.path.segments.last().unwrap();
52            let ident = &segment.ident;
53            let ident_str = ident.to_string();
54
55            if let Some(inner_ty) = get_slice_element_type(ty)
56                && let Some(arr_type) = get_typed_array_ident(inner_ty)
57            {
58                return Ok(quote! {
59                    let arr: ::js_sys::#arr_type = ::wasm_bindgen::JsCast::unchecked_into(res);
60                    Ok(::std::convert::Into::<#ty>::into(arr.to_vec()))
61                });
62            }
63
64            match ident_str.as_str() {
65                "f32" | "f64" | "i8" | "i16" | "i32" | "u8" | "u16" | "u32" => Ok(quote! {
66                    res.as_f64().map(|v| v as #ty).ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a number"))
67                }),
68                "i64" | "u64" => Ok(quote! {
69                    ::std::convert::TryInto::<#ty>::try_into(res).map_err(|_| ::wasm_bindgen::JsValue::from_str("Expected a BigInt"))
70                }),
71                "bool" => Ok(quote! {
72                    res.as_bool().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a boolean"))
73                }),
74                "String" => Ok(quote! {
75                    res.as_string().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a string"))
76                }),
77                "JsValue" => Ok(quote! {
78                    Ok(res)
79                }),
80                "Option" => {
81                    let PathArguments::AngleBracketed(args) = &segment.arguments else {
82                        return Err(Error::new_spanned(
83                            ty,
84                            "Expected generic argument for Option",
85                        ));
86                    };
87                    let syn::GenericArgument::Type(inner_ty) = args.args.first().unwrap() else {
88                        return Err(Error::new_spanned(ty, "Expected type argument for Option"));
89                    };
90                    let inner_conversion = generate_return_conversion(inner_ty)?;
91                    Ok(quote! {
92                        if res.is_null() || res.is_undefined() {
93                            Ok(None)
94                        } else {
95                            let res = { #inner_conversion };
96                            res.map(Some)
97                        }
98                    })
99                }
100                _ => Ok(quote! {
101                    Ok(::wasm_bindgen::JsCast::unchecked_into::<#ty>(res))
102                }),
103            }
104        }
105        _ => Err(Error::new_spanned(
106            ty,
107            "Unsupported return type in type alias pattern. Use the `impl` escape hatch instead.",
108        )),
109    }
110}
111
112fn parse_item_type(item_type: &ItemType) -> syn::Result<proc_macro2::TokenStream> {
113    let Type::BareFn(bare_fn) = &*item_type.ty else {
114        return Err(Error::new_spanned(
115            &item_type.ty,
116            "Expected a function pointer type (e.g., `fn(x: f64)`)",
117        ));
118    };
119
120    let struct_ident = &item_type.ident;
121    let mut args = Vec::new();
122
123    for (i, arg) in bare_fn.inputs.iter().enumerate() {
124        let ident = match &arg.name {
125            Some((ident, _)) => ident.clone(),
126            None => format_ident!("arg{}", i),
127        };
128        args.push((ident, &arg.ty));
129    }
130
131    let parsed = ParsedSignature {
132        struct_ident,
133        args: args.clone(),
134        output: &bare_fn.output,
135    };
136
137    let abi_traits = generate_abi_traits(&parsed)?;
138
139    let mut fn_args = Vec::new();
140    let mut arg_conversions = Vec::new();
141    let mut call_args = Vec::new();
142    for (ident, ty) in &args {
143        fn_args.push(quote! { #ident: #ty });
144        let conversion = generate_conversion(ident, ty)?;
145        arg_conversions.push(conversion);
146        call_args.push(quote! { &#ident });
147    }
148
149    let args_len = call_args.len();
150    if args_len > 9 {
151        return Err(Error::new_spanned(
152            item_type,
153            "Functions with more than 9 arguments are not supported yet",
154        ));
155    }
156    let call_method_name = format_ident!("call{}", args_len);
157    let call_method = quote! { #call_method_name(&::wasm_bindgen::JsValue::NULL, #(#call_args),*) };
158
159    let output = parsed.output;
160    let (ret_type, ret_stmt) = match output {
161        ReturnType::Default => (quote! { () }, quote! { self.0.#call_method.map(|_| ()) }),
162        ReturnType::Type(_, ty) => {
163            let conversion = generate_return_conversion(ty)?;
164            (
165                quote! { #ty },
166                quote! {
167                    let res = self.0.#call_method?;
168                    #conversion
169                },
170            )
171        }
172    };
173
174    Ok(quote! {
175        pub struct #struct_ident(pub ::js_sys::Function);
176
177        impl #struct_ident {
178            pub fn call(&self, #(#fn_args),*) -> Result<#ret_type, ::wasm_bindgen::JsValue> {
179                #(#arg_conversions)*
180                #ret_stmt
181            }
182        }
183
184        #abi_traits
185    })
186}
187
188fn generate_conversion(ident: &Ident, ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
189    if let Type::ImplTrait(type_impl) = ty {
190        for bound in &type_impl.bounds {
191            if let syn::TypeParamBound::Trait(trait_bound) = bound
192                && let Some(segment) = trait_bound.path.segments.last()
193                && let PathArguments::AngleBracketed(args) = &segment.arguments
194                && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
195            {
196                match segment.ident.to_string().as_str() {
197                    "Into" => {
198                        let inner_conversion = generate_conversion(ident, inner_ty)?;
199                        return Ok(quote! {
200                            let #ident = ::std::convert::Into::<#inner_ty>::into(#ident);
201                            #inner_conversion
202                        });
203                    }
204                    "AsRef" => {
205                        if let Type::Slice(slice) = inner_ty {
206                            return Ok(generate_typed_array_conversion(ident, &slice.elem));
207                        }
208                    }
209                    _ => {}
210                }
211            }
212        }
213        return Err(Error::new_spanned(
214            ty,
215            "Unsupported `impl Trait`. Only `impl Into<T>` and `impl AsRef<[T]>` are supported.",
216        ));
217    }
218
219    if let Some(inner_ty) = get_slice_element_type(ty) {
220        Ok(generate_typed_array_conversion(ident, inner_ty))
221    } else {
222        Ok(quote! {
223            let #ident = ::std::convert::Into::<::wasm_bindgen::JsValue>::into(#ident);
224        })
225    }
226}
227
228fn generate_typed_array_conversion(ident: &Ident, inner_ty: &Type) -> proc_macro2::TokenStream {
229    if let Some(arr_type) = get_typed_array_ident(inner_ty) {
230        quote! {
231            let #ident = ::wasm_bindgen::JsValue::from(::js_sys::#arr_type::from(::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)));
232        }
233    } else {
234        quote! {
235            let #ident = ::wasm_bindgen::JsValue::from(
236                ::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)
237                    .iter()
238                    .map(::wasm_bindgen::JsValue::from)
239                    .collect::<::js_sys::Array>()
240            );
241        }
242    }
243}
244
245fn get_typed_array_ident(inner_ty: &Type) -> Option<proc_macro2::TokenStream> {
246    let inner_str = match inner_ty {
247        Type::Path(p) => p.path.segments.last().map(|s| s.ident.to_string()),
248        _ => None,
249    };
250
251    match inner_str.as_deref() {
252        Some("u8") => Some(quote! { Uint8Array }),
253        Some("i8") => Some(quote! { Int8Array }),
254        Some("u16") => Some(quote! { Uint16Array }),
255        Some("i16") => Some(quote! { Int16Array }),
256        Some("u32") => Some(quote! { Uint32Array }),
257        Some("i32") => Some(quote! { Int32Array }),
258        Some("f32") => Some(quote! { Float32Array }),
259        Some("f64") => Some(quote! { Float64Array }),
260        Some("u64") => Some(quote! { BigUint64Array }),
261        Some("i64") => Some(quote! { BigInt64Array }),
262        _ => None,
263    }
264}
265
266fn get_slice_element_type(ty: &Type) -> Option<&Type> {
267    match ty {
268        Type::Path(type_path) => {
269            let segment = type_path.path.segments.last()?;
270            // Types that implement AsRef<[T]> and we can easily extract T from AST
271            if matches!(
272                segment.ident.to_string().as_str(),
273                "Vec" | "Box" | "Arc" | "Rc"
274            ) && let PathArguments::AngleBracketed(args) = &segment.arguments
275                && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
276            {
277                if let Type::Slice(slice) = inner {
278                    return Some(&*slice.elem);
279                }
280                return Some(inner);
281            }
282        }
283        Type::Reference(type_ref) => {
284            if let Type::Slice(type_slice) = &*type_ref.elem {
285                return Some(&*type_slice.elem);
286            }
287            return get_slice_element_type(&type_ref.elem);
288        }
289        _ => {}
290    }
291    None
292}
293
294fn parse_item_impl(item_impl: &ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
295    if item_impl.trait_.is_some() {
296        return Err(Error::new_spanned(
297            item_impl,
298            "#[ts_function] cannot be applied to trait impls",
299        ));
300    }
301
302    let Type::Path(type_path) = &*item_impl.self_ty else {
303        return Err(Error::new_spanned(
304            &item_impl.self_ty,
305            "Expected a simple path for the struct",
306        ));
307    };
308
309    let struct_ident = type_path.path.get_ident().ok_or_else(|| {
310        Error::new_spanned(
311            &type_path.path,
312            "Expected a single identifier for the struct",
313        )
314    })?;
315
316    let method = item_impl
317        .items
318        .iter()
319        .find_map(|item| {
320            if let syn::ImplItem::Fn(method) = item
321                && method.sig.ident == "call"
322            {
323                return Some(method);
324            }
325            None
326        })
327        .ok_or_else(|| Error::new_spanned(item_impl, "Missing `call` method in impl block"))?;
328
329    let mut args = Vec::new();
330    let mut inputs_iter = method.sig.inputs.iter();
331
332    // Check first argument is `&self` or `&mut self`
333    match inputs_iter.next() {
334        Some(FnArg::Receiver(_)) => {}
335        _ => {
336            return Err(Error::new_spanned(
337                &method.sig,
338                "The `call` method must take `&self` or `&mut self` as its first parameter",
339            ));
340        }
341    }
342
343    for (i, arg) in inputs_iter.enumerate() {
344        let FnArg::Typed(pat_type) = arg else {
345            return Err(Error::new_spanned(arg, "Expected a typed argument"));
346        };
347
348        let ident = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
349            pat_ident.ident.clone()
350        } else {
351            format_ident!("arg{}", i)
352        };
353
354        args.push((ident, &*pat_type.ty));
355    }
356
357    let parsed = ParsedSignature {
358        struct_ident,
359        args,
360        output: &method.sig.output,
361    };
362
363    let abi_traits = generate_abi_traits(&parsed)?;
364
365    Ok(quote! {
366        #item_impl
367        #abi_traits
368    })
369}
370
371fn generate_abi_traits(parsed: &ParsedSignature) -> syn::Result<proc_macro2::TokenStream> {
372    let struct_ident = parsed.struct_ident;
373    let mut ts_args = Vec::new();
374
375    for (ident, ty) in &parsed.args {
376        let ts_ty = ty
377            .to_ts_type()
378            .map_err(|e| Error::new_spanned(ty, e.message))?
379            .to_string();
380        ts_args.push(format!("{}: {}", ident, ts_ty));
381    }
382
383    let ts_output = match parsed.output {
384        ReturnType::Default => "void".to_string(),
385        ReturnType::Type(_, ty) => ty
386            .to_ts_type()
387            .map_err(|e| Error::new_spanned(ty, e.message))?
388            .to_string(),
389    };
390
391    let ts_string = format!(
392        "type {} = ({}) => {};",
393        struct_ident,
394        ts_args.join(", "),
395        ts_output
396    );
397
398    let generated = quote! {
399        #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
400        const _: &'static str = #ts_string;
401
402        impl ::wasm_bindgen::describe::WasmDescribe for #struct_ident {
403            fn describe() {
404                <::js_sys::Function as ::wasm_bindgen::describe::WasmDescribe>::describe()
405            }
406        }
407
408        impl ::wasm_bindgen::convert::FromWasmAbi for #struct_ident {
409            type Abi = <::js_sys::Function as ::wasm_bindgen::convert::FromWasmAbi>::Abi;
410
411            unsafe fn from_abi(js: Self::Abi) -> Self {
412                Self(::js_sys::Function::from_abi(js))
413            }
414        }
415
416        impl ::wasm_bindgen::convert::OptionFromWasmAbi for #struct_ident {
417            fn is_none(abi: &Self::Abi) -> bool {
418                <::js_sys::Function as ::wasm_bindgen::convert::OptionFromWasmAbi>::is_none(abi)
419            }
420        }
421
422        impl From<::js_sys::Function> for #struct_ident {
423            fn from(f: ::js_sys::Function) -> Self {
424                Self(f)
425            }
426        }
427    };
428
429    Ok(generated)
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Runs the type-alias expansion and renders the generated tokens as a string.
    fn expand_alias(item_type: ItemType) -> String {
        parse_item_type(&item_type)
            .expect("type alias should expand")
            .to_string()
    }

    /// Runs the `impl` escape-hatch expansion and renders the generated tokens as a string.
    fn expand_impl(item_impl: ItemImpl) -> String {
        parse_item_impl(&item_impl)
            .expect("impl block should expand")
            .to_string()
    }

    /// Asserts that every snippet in `expected` occurs verbatim in `output`.
    fn assert_contains_all(output: &str, expected: &[&str]) {
        for snippet in expected {
            assert!(
                output.contains(*snippet),
                "missing `{snippet}` in:\n{output}"
            );
        }
    }

    #[test]
    fn test_item_type() {
        let output = expand_alias(parse_quote! {
            pub type OnClick = fn(x: f64, y: impl Into<f64>, arr: js_sys::Float64Array);
        });

        assert_contains_all(
            &output,
            &[
                "type OnClick = (x: number, y: number, arr: Float64Array) => void;",
                "pub struct OnClick (pub :: js_sys :: Function) ;",
                "pub fn call (& self , x : f64 , y : impl Into < f64 > , arr : js_sys :: Float64Array)",
            ],
        );
    }

    #[test]
    fn test_item_impl() {
        let output = expand_impl(parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {
                    // body
                }
            }
        });

        assert_contains_all(
            &output,
            &[
                "type OnScroll = (y: number) => void;",
                "impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll",
            ],
        );
    }

    #[test]
    fn test_recursive_generics() {
        let result_cb = expand_alias(parse_quote! {
            pub type ResultCb = fn(res: Result<String, i32>);
        });
        assert_contains_all(
            &result_cb,
            &["type ResultCb = (res: Result<string, number>) => void;"],
        );

        let nested_vec_cb = expand_alias(parse_quote! {
            pub type NestedVecCb = fn(args: Vec<Vec<f64>>);
        });
        assert_contains_all(
            &nested_vec_cb,
            &["type NestedVecCb = (args: Float64Array[]) => void;"],
        );
    }
}