1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 Error, FnArg, GenericArgument, Ident, Item, ItemImpl, ItemType, PathArguments, ReturnType,
5 Type, parse_macro_input,
6};
7
8#[macro_use]
9mod ts_type;
10mod ts_macro;
11
12use crate::ts_type::ToTsType;
13
/// Attribute macro entry point for `#[ts]`.
///
/// Passes the attribute arguments and the annotated item straight through to
/// `ts_macro::ts` and returns the token stream that function produces.
#[proc_macro_attribute]
pub fn ts(attr: TokenStream, input: TokenStream) -> TokenStream {
    ts_macro::ts(attr, input)
}
18
/// Borrowed description of a callback signature.
///
/// Built by both `parse_item_type` and `parse_item_impl` and consumed by
/// `generate_abi_traits` to emit the TypeScript declaration and ABI impls.
struct ParsedSignature<'a> {
    /// Identifier of the wrapper struct the glue code is generated for.
    struct_ident: &'a Ident,
    /// Argument names paired with their declared Rust types, in declaration order.
    args: Vec<(Ident, &'a Type)>,
    /// Declared return type of the callback signature.
    output: &'a ReturnType,
}
24
25#[proc_macro_attribute]
26pub fn ts_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
27 let item = parse_macro_input!(item as Item);
28
29 let result = match &item {
30 Item::Type(item_type) => parse_item_type(item_type),
31 Item::Impl(item_impl) => parse_item_impl(item_impl),
32 _ => {
33 return Error::new_spanned(
34 item,
35 "#[ts_function] can only be applied to a type alias or an impl block",
36 )
37 .to_compile_error()
38 .into();
39 }
40 };
41
42 match result {
43 Ok(tokens) => tokens.into(),
44 Err(err) => err.to_compile_error().into(),
45 }
46}
47
48fn generate_return_conversion(ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
49 match ty {
50 Type::Path(type_path) => {
51 let segment = type_path.path.segments.last().unwrap();
52 let ident = &segment.ident;
53 let ident_str = ident.to_string();
54
55 if let Some(inner_ty) = get_slice_element_type(ty)
56 && let Some(arr_type) = get_typed_array_ident(inner_ty)
57 {
58 return Ok(quote! {
59 let arr: ::js_sys::#arr_type = ::wasm_bindgen::JsCast::unchecked_into(res);
60 Ok(::std::convert::Into::<#ty>::into(arr.to_vec()))
61 });
62 }
63
64 match ident_str.as_str() {
65 "f32" | "f64" | "i8" | "i16" | "i32" | "u8" | "u16" | "u32" => Ok(quote! {
66 res.as_f64().map(|v| v as #ty).ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a number"))
67 }),
68 "i64" | "u64" => Ok(quote! {
69 ::std::convert::TryInto::<#ty>::try_into(res).map_err(|_| ::wasm_bindgen::JsValue::from_str("Expected a BigInt"))
70 }),
71 "bool" => Ok(quote! {
72 res.as_bool().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a boolean"))
73 }),
74 "String" => Ok(quote! {
75 res.as_string().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a string"))
76 }),
77 "JsValue" => Ok(quote! {
78 Ok(res)
79 }),
80 "Option" => {
81 let PathArguments::AngleBracketed(args) = &segment.arguments else {
82 return Err(Error::new_spanned(
83 ty,
84 "Expected generic argument for Option",
85 ));
86 };
87 let syn::GenericArgument::Type(inner_ty) = args.args.first().unwrap() else {
88 return Err(Error::new_spanned(ty, "Expected type argument for Option"));
89 };
90 let inner_conversion = generate_return_conversion(inner_ty)?;
91 Ok(quote! {
92 if res.is_null() || res.is_undefined() {
93 Ok(None)
94 } else {
95 let res = { #inner_conversion };
96 res.map(Some)
97 }
98 })
99 }
100 _ => Ok(quote! {
101 Ok(::wasm_bindgen::JsCast::unchecked_into::<#ty>(res))
102 }),
103 }
104 }
105 _ => Err(Error::new_spanned(
106 ty,
107 "Unsupported return type in type alias pattern. Use the `impl` escape hatch instead.",
108 )),
109 }
110}
111
112fn parse_item_type(item_type: &ItemType) -> syn::Result<proc_macro2::TokenStream> {
113 let Type::BareFn(bare_fn) = &*item_type.ty else {
114 return Err(Error::new_spanned(
115 &item_type.ty,
116 "Expected a function pointer type (e.g., `fn(x: f64)`)",
117 ));
118 };
119
120 let struct_ident = &item_type.ident;
121 let mut args = Vec::new();
122
123 for (i, arg) in bare_fn.inputs.iter().enumerate() {
124 let ident = match &arg.name {
125 Some((ident, _)) => ident.clone(),
126 None => format_ident!("arg{}", i),
127 };
128 args.push((ident, &arg.ty));
129 }
130
131 let parsed = ParsedSignature {
132 struct_ident,
133 args: args.clone(),
134 output: &bare_fn.output,
135 };
136
137 let abi_traits = generate_abi_traits(&parsed)?;
138
139 let mut fn_args = Vec::new();
140 let mut arg_conversions = Vec::new();
141 let mut call_args = Vec::new();
142 for (ident, ty) in &args {
143 fn_args.push(quote! { #ident: #ty });
144 let conversion = generate_conversion(ident, ty)?;
145 arg_conversions.push(conversion);
146 call_args.push(quote! { &#ident });
147 }
148
149 let args_len = call_args.len();
150 if args_len > 9 {
151 return Err(Error::new_spanned(
152 item_type,
153 "Functions with more than 9 arguments are not supported yet",
154 ));
155 }
156 let call_method_name = format_ident!("call{}", args_len);
157 let call_method = quote! { #call_method_name(&::wasm_bindgen::JsValue::NULL, #(#call_args),*) };
158
159 let output = parsed.output;
160 let (ret_type, ret_stmt) = match output {
161 ReturnType::Default => (quote! { () }, quote! { self.0.#call_method.map(|_| ()) }),
162 ReturnType::Type(_, ty) => {
163 let conversion = generate_return_conversion(ty)?;
164 (
165 quote! { #ty },
166 quote! {
167 let res = self.0.#call_method?;
168 #conversion
169 },
170 )
171 }
172 };
173
174 Ok(quote! {
175 pub struct #struct_ident(pub ::js_sys::Function);
176
177 impl #struct_ident {
178 pub fn call(&self, #(#fn_args),*) -> Result<#ret_type, ::wasm_bindgen::JsValue> {
179 #(#arg_conversions)*
180 #ret_stmt
181 }
182 }
183
184 #abi_traits
185 })
186}
187
188fn generate_conversion(ident: &Ident, ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
189 if let Type::ImplTrait(type_impl) = ty {
190 for bound in &type_impl.bounds {
191 if let syn::TypeParamBound::Trait(trait_bound) = bound
192 && let Some(segment) = trait_bound.path.segments.last()
193 && let PathArguments::AngleBracketed(args) = &segment.arguments
194 && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
195 {
196 match segment.ident.to_string().as_str() {
197 "Into" => {
198 let inner_conversion = generate_conversion(ident, inner_ty)?;
199 return Ok(quote! {
200 let #ident = ::std::convert::Into::<#inner_ty>::into(#ident);
201 #inner_conversion
202 });
203 }
204 "AsRef" => {
205 if let Type::Slice(slice) = inner_ty {
206 return Ok(generate_typed_array_conversion(ident, &slice.elem));
207 }
208 }
209 _ => {}
210 }
211 }
212 }
213 return Err(Error::new_spanned(
214 ty,
215 "Unsupported `impl Trait`. Only `impl Into<T>` and `impl AsRef<[T]>` are supported.",
216 ));
217 }
218
219 if let Some(inner_ty) = get_slice_element_type(ty) {
220 Ok(generate_typed_array_conversion(ident, inner_ty))
221 } else {
222 Ok(quote! {
223 let #ident = ::std::convert::Into::<::wasm_bindgen::JsValue>::into(#ident);
224 })
225 }
226}
227
228fn generate_typed_array_conversion(ident: &Ident, inner_ty: &Type) -> proc_macro2::TokenStream {
229 if let Some(arr_type) = get_typed_array_ident(inner_ty) {
230 quote! {
231 let #ident = ::wasm_bindgen::JsValue::from(::js_sys::#arr_type::from(::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)));
232 }
233 } else {
234 quote! {
235 let #ident = ::wasm_bindgen::JsValue::from(
236 ::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)
237 .iter()
238 .map(::wasm_bindgen::JsValue::from)
239 .collect::<::js_sys::Array>()
240 );
241 }
242 }
243}
244
245fn get_typed_array_ident(inner_ty: &Type) -> Option<proc_macro2::TokenStream> {
246 let inner_str = match inner_ty {
247 Type::Path(p) => p.path.segments.last().map(|s| s.ident.to_string()),
248 _ => None,
249 };
250
251 match inner_str.as_deref() {
252 Some("u8") => Some(quote! { Uint8Array }),
253 Some("i8") => Some(quote! { Int8Array }),
254 Some("u16") => Some(quote! { Uint16Array }),
255 Some("i16") => Some(quote! { Int16Array }),
256 Some("u32") => Some(quote! { Uint32Array }),
257 Some("i32") => Some(quote! { Int32Array }),
258 Some("f32") => Some(quote! { Float32Array }),
259 Some("f64") => Some(quote! { Float64Array }),
260 Some("u64") => Some(quote! { BigUint64Array }),
261 Some("i64") => Some(quote! { BigInt64Array }),
262 _ => None,
263 }
264}
265
266fn get_slice_element_type(ty: &Type) -> Option<&Type> {
267 match ty {
268 Type::Path(type_path) => {
269 let segment = type_path.path.segments.last()?;
270 if matches!(
272 segment.ident.to_string().as_str(),
273 "Vec" | "Box" | "Arc" | "Rc"
274 ) && let PathArguments::AngleBracketed(args) = &segment.arguments
275 && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
276 {
277 if let Type::Slice(slice) = inner {
278 return Some(&*slice.elem);
279 }
280 return Some(inner);
281 }
282 }
283 Type::Reference(type_ref) => {
284 if let Type::Slice(type_slice) = &*type_ref.elem {
285 return Some(&*type_slice.elem);
286 }
287 return get_slice_element_type(&type_ref.elem);
288 }
289 _ => {}
290 }
291 None
292}
293
294fn parse_item_impl(item_impl: &ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
295 if item_impl.trait_.is_some() {
296 return Err(Error::new_spanned(
297 item_impl,
298 "#[ts_function] cannot be applied to trait impls",
299 ));
300 }
301
302 let Type::Path(type_path) = &*item_impl.self_ty else {
303 return Err(Error::new_spanned(
304 &item_impl.self_ty,
305 "Expected a simple path for the struct",
306 ));
307 };
308
309 let struct_ident = type_path.path.get_ident().ok_or_else(|| {
310 Error::new_spanned(
311 &type_path.path,
312 "Expected a single identifier for the struct",
313 )
314 })?;
315
316 let method = item_impl
317 .items
318 .iter()
319 .find_map(|item| {
320 if let syn::ImplItem::Fn(method) = item
321 && method.sig.ident == "call"
322 {
323 return Some(method);
324 }
325 None
326 })
327 .ok_or_else(|| Error::new_spanned(item_impl, "Missing `call` method in impl block"))?;
328
329 let mut args = Vec::new();
330 let mut inputs_iter = method.sig.inputs.iter();
331
332 match inputs_iter.next() {
334 Some(FnArg::Receiver(_)) => {}
335 _ => {
336 return Err(Error::new_spanned(
337 &method.sig,
338 "The `call` method must take `&self` or `&mut self` as its first parameter",
339 ));
340 }
341 }
342
343 for (i, arg) in inputs_iter.enumerate() {
344 let FnArg::Typed(pat_type) = arg else {
345 return Err(Error::new_spanned(arg, "Expected a typed argument"));
346 };
347
348 let ident = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
349 pat_ident.ident.clone()
350 } else {
351 format_ident!("arg{}", i)
352 };
353
354 args.push((ident, &*pat_type.ty));
355 }
356
357 let parsed = ParsedSignature {
358 struct_ident,
359 args,
360 output: &method.sig.output,
361 };
362
363 let abi_traits = generate_abi_traits(&parsed)?;
364
365 Ok(quote! {
366 #item_impl
367 #abi_traits
368 })
369}
370
371fn generate_abi_traits(parsed: &ParsedSignature) -> syn::Result<proc_macro2::TokenStream> {
372 let struct_ident = parsed.struct_ident;
373 let mut ts_args = Vec::new();
374
375 for (ident, ty) in &parsed.args {
376 let ts_ty = ty
377 .to_ts_type()
378 .map_err(|e| Error::new_spanned(ty, e.message))?
379 .to_string();
380 ts_args.push(format!("{}: {}", ident, ts_ty));
381 }
382
383 let ts_output = match parsed.output {
384 ReturnType::Default => "void".to_string(),
385 ReturnType::Type(_, ty) => ty
386 .to_ts_type()
387 .map_err(|e| Error::new_spanned(ty, e.message))?
388 .to_string(),
389 };
390
391 let ts_string = format!(
392 "type {} = ({}) => {};",
393 struct_ident,
394 ts_args.join(", "),
395 ts_output
396 );
397
398 let generated = quote! {
399 #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
400 const _: &'static str = #ts_string;
401
402 impl ::wasm_bindgen::describe::WasmDescribe for #struct_ident {
403 fn describe() {
404 <::js_sys::Function as ::wasm_bindgen::describe::WasmDescribe>::describe()
405 }
406 }
407
408 impl ::wasm_bindgen::convert::FromWasmAbi for #struct_ident {
409 type Abi = <::js_sys::Function as ::wasm_bindgen::convert::FromWasmAbi>::Abi;
410
411 unsafe fn from_abi(js: Self::Abi) -> Self {
412 Self(::js_sys::Function::from_abi(js))
413 }
414 }
415
416 impl ::wasm_bindgen::convert::OptionFromWasmAbi for #struct_ident {
417 fn is_none(abi: &Self::Abi) -> bool {
418 <::js_sys::Function as ::wasm_bindgen::convert::OptionFromWasmAbi>::is_none(abi)
419 }
420 }
421
422 impl From<::js_sys::Function> for #struct_ident {
423 fn from(f: ::js_sys::Function) -> Self {
424 Self(f)
425 }
426 }
427 };
428
429 Ok(generated)
430}
431
/// Unit tests that run the expansion helpers directly on parsed items and
/// check the stringified token output.
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// The type-alias form emits the TypeScript declaration, the wrapper struct
    /// and a `call` method that keeps the declared argument types.
    #[test]
    fn test_item_type() {
        let item_type: ItemType = parse_quote! {
            pub type OnClick = fn(x: f64, y: impl Into<f64>, arr: js_sys::Float64Array);
        };
        let result = parse_item_type(&item_type).unwrap();
        let result_str = result.to_string();

        assert!(
            result_str
                .contains("type OnClick = (x: number, y: number, arr: Float64Array) => void;")
        );
        // Token streams stringify with spaces between tokens, hence the spacing below.
        assert!(result_str.contains("pub struct OnClick (pub :: js_sys :: Function) ;"));
        assert!(result_str.contains(
            "pub fn call (& self , x : f64 , y : impl Into < f64 > , arr : js_sys :: Float64Array)"
        ));
    }

    /// The impl-block form derives the TypeScript declaration from `call` and
    /// emits the `WasmDescribe` impl for the struct.
    #[test]
    fn test_item_impl() {
        let item_impl: ItemImpl = parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {
                }
            }
        };
        let result = parse_item_impl(&item_impl).unwrap();
        let result_str = result.to_string();

        assert!(result_str.contains("type OnScroll = (y: number) => void;"));
        assert!(
            result_str.contains("impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll")
        );
    }

    /// Generic arguments are translated recursively when building the
    /// TypeScript declaration.
    #[test]
    fn test_recursive_generics() {
        let item_type: ItemType = parse_quote! {
            pub type ResultCb = fn(res: Result<String, i32>);
        };
        let result = parse_item_type(&item_type).unwrap();
        let result_str = result.to_string();

        assert!(result_str.contains("type ResultCb = (res: Result<string, number>) => void;"));

        let item_type: ItemType = parse_quote! {
            pub type NestedVecCb = fn(args: Vec<Vec<f64>>);
        };
        let result = parse_item_type(&item_type).unwrap();
        let result_str = result.to_string();

        assert!(result_str.contains("type NestedVecCb = (args: Float64Array[]) => void;"));
    }
}