1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 Error, FnArg, GenericArgument, Ident, Item, ItemImpl, ItemType, PathArguments, ReturnType,
5 Type, parse_macro_input,
6};
7
8#[macro_use]
9mod ts_type;
10mod ts_macro;
11
12use crate::ts_type::ToTsType;
13
14#[proc_macro_attribute]
34pub fn ts(attr: TokenStream, input: TokenStream) -> TokenStream {
35 ts_macro::ts(attr, input)
36}
37
38struct ParsedSignature<'a> {
39 struct_ident: &'a Ident,
40 args: Vec<(Ident, &'a Type)>,
41 output: &'a ReturnType,
42}
43
44#[proc_macro_attribute]
90pub fn ts_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
91 let item = parse_macro_input!(item as Item);
92
93 let result = match &item {
94 Item::Type(item_type) => parse_item_type(item_type),
95 Item::Impl(item_impl) => parse_item_impl(item_impl),
96 _ => {
97 return Error::new_spanned(
98 item,
99 "#[ts_function] can only be applied to a type alias or an impl block",
100 )
101 .to_compile_error()
102 .into();
103 }
104 };
105
106 match result {
107 Ok(tokens) => tokens.into(),
108 Err(err) => err.to_compile_error().into(),
109 }
110}
111
112fn generate_return_conversion(ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
113 match ty {
114 Type::Path(type_path) => {
115 let segment = type_path.path.segments.last().unwrap();
116 let ident = &segment.ident;
117 let ident_str = ident.to_string();
118
119 if let Some(inner_ty) = get_slice_element_type(ty)
120 && let Some(arr_type) = get_typed_array_ident(inner_ty)
121 {
122 return Ok(quote! {
123 let arr: ::js_sys::#arr_type = ::wasm_bindgen::JsCast::unchecked_into(res);
124 Ok(::std::convert::Into::<#ty>::into(arr.to_vec()))
125 });
126 }
127
128 match ident_str.as_str() {
129 "f32" | "f64" | "i8" | "i16" | "i32" | "u8" | "u16" | "u32" => Ok(quote! {
130 res.as_f64().map(|v| v as #ty).ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a number"))
131 }),
132 "i64" | "u64" => Ok(quote! {
133 ::std::convert::TryInto::<#ty>::try_into(res).map_err(|_| ::wasm_bindgen::JsValue::from_str("Expected a BigInt"))
134 }),
135 "bool" => Ok(quote! {
136 res.as_bool().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a boolean"))
137 }),
138 "String" => Ok(quote! {
139 res.as_string().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a string"))
140 }),
141 "JsValue" => Ok(quote! {
142 Ok(res)
143 }),
144 "Option" => {
145 let PathArguments::AngleBracketed(args) = &segment.arguments else {
146 return Err(Error::new_spanned(
147 ty,
148 "Expected generic argument for Option",
149 ));
150 };
151 let syn::GenericArgument::Type(inner_ty) = args.args.first().unwrap() else {
152 return Err(Error::new_spanned(ty, "Expected type argument for Option"));
153 };
154 let inner_conversion = generate_return_conversion(inner_ty)?;
155 Ok(quote! {
156 if res.is_null() || res.is_undefined() {
157 Ok(None)
158 } else {
159 let res = { #inner_conversion };
160 res.map(Some)
161 }
162 })
163 }
164 _ => Ok(quote! {
165 Ok(::wasm_bindgen::JsCast::unchecked_into::<#ty>(res))
166 }),
167 }
168 }
169 _ => Err(Error::new_spanned(
170 ty,
171 "Unsupported return type in type alias pattern. Use the `impl` escape hatch instead.",
172 )),
173 }
174}
175
176fn parse_item_type(item_type: &ItemType) -> syn::Result<proc_macro2::TokenStream> {
177 let Type::BareFn(bare_fn) = &*item_type.ty else {
178 return Err(Error::new_spanned(
179 &item_type.ty,
180 "Expected a function pointer type (e.g., `fn(x: f64)`)",
181 ));
182 };
183
184 let struct_ident = &item_type.ident;
185 let mut args = Vec::new();
186
187 for (i, arg) in bare_fn.inputs.iter().enumerate() {
188 let ident = match &arg.name {
189 Some((ident, _)) => ident.clone(),
190 None => format_ident!("arg{}", i),
191 };
192 args.push((ident, &arg.ty));
193 }
194
195 let parsed = ParsedSignature {
196 struct_ident,
197 args: args.clone(),
198 output: &bare_fn.output,
199 };
200
201 let abi_traits = generate_abi_traits(&parsed)?;
202
203 let mut fn_args = Vec::new();
204 let mut arg_conversions = Vec::new();
205 let mut call_args = Vec::new();
206 for (ident, ty) in &args {
207 fn_args.push(quote! { #ident: #ty });
208 let conversion = generate_conversion(ident, ty)?;
209 arg_conversions.push(conversion);
210 call_args.push(quote! { &#ident });
211 }
212
213 let args_len = call_args.len();
214 if args_len > 9 {
215 return Err(Error::new_spanned(
216 item_type,
217 "Functions with more than 9 arguments are not supported yet",
218 ));
219 }
220 let call_method_name = format_ident!("call{}", args_len);
221 let call_method = quote! { #call_method_name(&::wasm_bindgen::JsValue::NULL, #(#call_args),*) };
222
223 let output = parsed.output;
224 let (ret_type, ret_stmt) = match output {
225 ReturnType::Default => (quote! { () }, quote! { self.0.#call_method.map(|_| ()) }),
226 ReturnType::Type(_, ty) => {
227 let conversion = generate_return_conversion(ty)?;
228 (
229 quote! { #ty },
230 quote! {
231 let res = self.0.#call_method?;
232 #conversion
233 },
234 )
235 }
236 };
237
238 Ok(quote! {
239 pub struct #struct_ident(pub ::js_sys::Function);
240
241 impl #struct_ident {
242 pub fn call(&self, #(#fn_args),*) -> Result<#ret_type, ::wasm_bindgen::JsValue> {
243 #(#arg_conversions)*
244 #ret_stmt
245 }
246 }
247
248 #abi_traits
249 })
250}
251
252fn generate_conversion(ident: &Ident, ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
253 if let Type::ImplTrait(type_impl) = ty {
254 for bound in &type_impl.bounds {
255 if let syn::TypeParamBound::Trait(trait_bound) = bound
256 && let Some(segment) = trait_bound.path.segments.last()
257 && let PathArguments::AngleBracketed(args) = &segment.arguments
258 && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
259 {
260 match segment.ident.to_string().as_str() {
261 "Into" => {
262 let inner_conversion = generate_conversion(ident, inner_ty)?;
263 return Ok(quote! {
264 let #ident = ::std::convert::Into::<#inner_ty>::into(#ident);
265 #inner_conversion
266 });
267 }
268 "AsRef" => {
269 if let Type::Slice(slice) = inner_ty {
270 return Ok(generate_typed_array_conversion(ident, &slice.elem));
271 }
272 }
273 _ => {}
274 }
275 }
276 }
277 return Err(Error::new_spanned(
278 ty,
279 "Unsupported `impl Trait`. Only `impl Into<T>` and `impl AsRef<[T]>` are supported.",
280 ));
281 }
282
283 if let Some(inner_ty) = get_slice_element_type(ty) {
284 Ok(generate_typed_array_conversion(ident, inner_ty))
285 } else {
286 Ok(quote! {
287 let #ident = ::std::convert::Into::<::wasm_bindgen::JsValue>::into(#ident);
288 })
289 }
290}
291
292fn generate_typed_array_conversion(ident: &Ident, inner_ty: &Type) -> proc_macro2::TokenStream {
293 if let Some(arr_type) = get_typed_array_ident(inner_ty) {
294 quote! {
295 let #ident = ::wasm_bindgen::JsValue::from(::js_sys::#arr_type::from(::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)));
296 }
297 } else {
298 quote! {
299 let #ident = ::wasm_bindgen::JsValue::from(
300 ::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)
301 .iter()
302 .map(::wasm_bindgen::JsValue::from)
303 .collect::<::js_sys::Array>()
304 );
305 }
306 }
307}
308
309fn get_typed_array_ident(inner_ty: &Type) -> Option<proc_macro2::TokenStream> {
310 let inner_str = match inner_ty {
311 Type::Path(p) => p.path.segments.last().map(|s| s.ident.to_string()),
312 _ => None,
313 };
314
315 match inner_str.as_deref() {
316 Some("u8") => Some(quote! { Uint8Array }),
317 Some("i8") => Some(quote! { Int8Array }),
318 Some("u16") => Some(quote! { Uint16Array }),
319 Some("i16") => Some(quote! { Int16Array }),
320 Some("u32") => Some(quote! { Uint32Array }),
321 Some("i32") => Some(quote! { Int32Array }),
322 Some("f32") => Some(quote! { Float32Array }),
323 Some("f64") => Some(quote! { Float64Array }),
324 Some("u64") => Some(quote! { BigUint64Array }),
325 Some("i64") => Some(quote! { BigInt64Array }),
326 _ => None,
327 }
328}
329
330fn get_slice_element_type(ty: &Type) -> Option<&Type> {
331 match ty {
332 Type::Path(type_path) => {
333 let segment = type_path.path.segments.last()?;
334 if matches!(
336 segment.ident.to_string().as_str(),
337 "Vec" | "Box" | "Arc" | "Rc"
338 ) && let PathArguments::AngleBracketed(args) = &segment.arguments
339 && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
340 {
341 if let Type::Slice(slice) = inner {
342 return Some(&*slice.elem);
343 }
344 return Some(inner);
345 }
346 }
347 Type::Reference(type_ref) => {
348 if let Type::Slice(type_slice) = &*type_ref.elem {
349 return Some(&*type_slice.elem);
350 }
351 return get_slice_element_type(&type_ref.elem);
352 }
353 _ => {}
354 }
355 None
356}
357
358fn parse_item_impl(item_impl: &ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
359 if item_impl.trait_.is_some() {
360 return Err(Error::new_spanned(
361 item_impl,
362 "#[ts_function] cannot be applied to trait impls",
363 ));
364 }
365
366 let Type::Path(type_path) = &*item_impl.self_ty else {
367 return Err(Error::new_spanned(
368 &item_impl.self_ty,
369 "Expected a simple path for the struct",
370 ));
371 };
372
373 let struct_ident = type_path.path.get_ident().ok_or_else(|| {
374 Error::new_spanned(
375 &type_path.path,
376 "Expected a single identifier for the struct",
377 )
378 })?;
379
380 let method = item_impl
381 .items
382 .iter()
383 .find_map(|item| {
384 if let syn::ImplItem::Fn(method) = item
385 && method.sig.ident == "call"
386 {
387 return Some(method);
388 }
389 None
390 })
391 .ok_or_else(|| Error::new_spanned(item_impl, "Missing `call` method in impl block"))?;
392
393 let mut args = Vec::new();
394 let mut inputs_iter = method.sig.inputs.iter();
395
396 match inputs_iter.next() {
398 Some(FnArg::Receiver(_)) => {}
399 _ => {
400 return Err(Error::new_spanned(
401 &method.sig,
402 "The `call` method must take `&self` or `&mut self` as its first parameter",
403 ));
404 }
405 }
406
407 for (i, arg) in inputs_iter.enumerate() {
408 let FnArg::Typed(pat_type) = arg else {
409 return Err(Error::new_spanned(arg, "Expected a typed argument"));
410 };
411
412 let ident = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
413 pat_ident.ident.clone()
414 } else {
415 format_ident!("arg{}", i)
416 };
417
418 args.push((ident, &*pat_type.ty));
419 }
420
421 let parsed = ParsedSignature {
422 struct_ident,
423 args,
424 output: &method.sig.output,
425 };
426
427 let abi_traits = generate_abi_traits(&parsed)?;
428
429 Ok(quote! {
430 #item_impl
431 #abi_traits
432 })
433}
434
435fn generate_abi_traits(parsed: &ParsedSignature) -> syn::Result<proc_macro2::TokenStream> {
436 let struct_ident = parsed.struct_ident;
437 let mut ts_args = Vec::new();
438
439 for (ident, ty) in &parsed.args {
440 let ts_ty = ty
441 .to_ts_type()
442 .map_err(|e| Error::new_spanned(ty, e.message))?
443 .to_string();
444 ts_args.push(format!("{}: {}", ident, ts_ty));
445 }
446
447 let ts_output = match parsed.output {
448 ReturnType::Default => "void".to_string(),
449 ReturnType::Type(_, ty) => ty
450 .to_ts_type()
451 .map_err(|e| Error::new_spanned(ty, e.message))?
452 .to_string(),
453 };
454
455 let ts_string = format!(
456 "type {} = ({}) => {};",
457 struct_ident,
458 ts_args.join(", "),
459 ts_output
460 );
461
462 let generated = quote! {
463 #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
464 const _: &'static str = #ts_string;
465
466 impl ::wasm_bindgen::describe::WasmDescribe for #struct_ident {
467 fn describe() {
468 <::js_sys::Function as ::wasm_bindgen::describe::WasmDescribe>::describe()
469 }
470 }
471
472 impl ::wasm_bindgen::convert::FromWasmAbi for #struct_ident {
473 type Abi = <::js_sys::Function as ::wasm_bindgen::convert::FromWasmAbi>::Abi;
474
475 unsafe fn from_abi(js: Self::Abi) -> Self {
476 Self(::js_sys::Function::from_abi(js))
477 }
478 }
479
480 impl ::wasm_bindgen::convert::OptionFromWasmAbi for #struct_ident {
481 fn is_none(abi: &Self::Abi) -> bool {
482 <::js_sys::Function as ::wasm_bindgen::convert::OptionFromWasmAbi>::is_none(abi)
483 }
484 }
485
486 impl From<::js_sys::Function> for #struct_ident {
487 fn from(f: ::js_sys::Function) -> Self {
488 Self(f)
489 }
490 }
491 };
492
493 Ok(generated)
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Expands a type alias with `parse_item_type` and renders the tokens as text.
    fn expand_alias(alias: ItemType) -> String {
        parse_item_type(&alias).unwrap().to_string()
    }

    /// A type alias yields the TypeScript declaration, the wrapper struct and `call`.
    #[test]
    fn test_item_type() {
        let expanded = expand_alias(parse_quote! {
            pub type OnClick = fn(x: f64, y: impl Into<f64>, arr: js_sys::Float64Array);
        });

        assert!(
            expanded.contains("type OnClick = (x: number, y: number, arr: Float64Array) => void;")
        );
        assert!(expanded.contains("pub struct OnClick (pub :: js_sys :: Function) ;"));
        assert!(expanded.contains(
            "pub fn call (& self , x : f64 , y : impl Into < f64 > , arr : js_sys :: Float64Array)"
        ));
    }

    /// An `impl` block gets the TypeScript declaration and the ABI impls appended.
    #[test]
    fn test_item_impl() {
        let block: ItemImpl = parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {
                }
            }
        };
        let expanded = parse_item_impl(&block).unwrap().to_string();

        assert!(expanded.contains("type OnScroll = (y: number) => void;"));
        assert!(
            expanded.contains("impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll")
        );
    }

    /// Generic arguments are translated recursively in the TypeScript declaration.
    #[test]
    fn test_recursive_generics() {
        let result_cb = expand_alias(parse_quote! {
            pub type ResultCb = fn(res: Result<String, i32>);
        });
        assert!(result_cb.contains("type ResultCb = (res: Result<string, number>) => void;"));

        let nested_vec_cb = expand_alias(parse_quote! {
            pub type NestedVecCb = fn(args: Vec<Vec<f64>>);
        });
        assert!(nested_vec_cb.contains("type NestedVecCb = (args: Float64Array[]) => void;"));
    }
}
555}