trait_ffi/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote};
7use syn::{Ident, ItemImpl, ItemTrait, parse_macro_input, spanned::Spanned};
8
/// Aborts the surrounding proc-macro entry point by returning a `compile_error!`
/// invocation that carries `$message` and is reported at `$span`.
macro_rules! bail {
    ($span:expr, $message:expr) => {
        return syn::Error::new($span, $message).into_compile_error().into();
    };
}
14
/// Returns the package name from the `CARGO_PKG_NAME` environment variable, read at
/// macro-expansion time (Cargo sets it for the crate being compiled).
///
/// Falls back to `"unknown"` when the variable is missing or is not valid Unicode.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    }
}
18
/// Returns the package version from the `CARGO_PKG_VERSION` environment variable, read at
/// macro-expansion time.
///
/// Falls back to `"0.1.0"` when the variable is missing or is not valid Unicode.
fn get_crate_version() -> String {
    std::env::var("CARGO_PKG_VERSION")
        .ok()
        .unwrap_or_else(|| String::from("0.1.0"))
}
22
23fn prefix_version() -> String {
24    let version = lenient_semver::parse(&get_crate_version()).unwrap();
25    let major = version.major;
26    let minor = version.minor;
27    if major == 0 {
28        format!("0_{minor}")
29    } else {
30        major.to_string()
31    }
32}
33
34fn extern_fn_name(crate_name: &str, fn_name: &Ident) -> Ident {
35    let crate_name = crate_name.to_lowercase().replace("-", "_");
36    // let version = prefix_version();
37
38    format_ident!("__{crate_name}_{fn_name}")
39}
40
/// Parses the arguments of `#[def_extern_trait(...)]`.
///
/// Supported comma-separated arguments:
/// - `abi = "c"` or `abi = "rust"`. Defaults to `"rust"`.
/// - `not_def_impl`: a flag that disables generation of the `impl_trait!` macro.
/// - `mod_path = "a::b"`: the module path `impl_trait!` uses to locate the re-exported
///   `impl_extern_trait` attribute.
///
/// Returns `(abi, not_def_impl, mod_path)`.
///
/// # Errors
/// Returns a message suitable for `compile_error!` in three cases:
/// - an argument is unknown;
/// - a `key = value` argument does not have a string-literal value;
/// - `abi` is not `"c"` or `"rust"`.
///
/// Such arguments used to be ignored silently; for example, `abi = c` quietly fell back
/// to the Rust ABI.
fn parse_def_extern_trait_args(
    args: TokenStream,
) -> Result<(String, bool, Option<String>), String> {
    // Text between the first and last `"` of a rendered string literal, if any.
    fn string_literal_contents(value: &str) -> Option<String> {
        let start = value.find('"')?;
        let end = value.rfind('"')?;
        (start < end).then(|| value[start + 1..end].to_string())
    }

    fn unknown_parameter(name: &str) -> String {
        format!(
            "Unknown parameter `{name}`. Supported parameters: abi = \"c\" | \"rust\", not_def_impl, mod_path = \"path::to::module\""
        )
    }

    if args.is_empty() {
        // Defaults: Rust ABI, generate `impl_trait!`, no custom module path.
        return Ok(("rust".to_string(), false, None));
    }

    let args_str = args.to_string();
    let mut abi: Option<String> = None;
    let mut not_def_impl = false;
    let mut mod_path: Option<String> = None;

    // The accepted syntax is simple enough to handle on the rendered token string:
    // `key = "value"` pairs and bare flags separated by commas. Empty parts (e.g. from a
    // trailing comma) are skipped.
    for part in args_str.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        match part.split_once('=') {
            None if part == "not_def_impl" => not_def_impl = true,
            None => return Err(unknown_parameter(part)),
            Some((key, value)) => {
                let key = key.trim();
                let slot = match key {
                    "abi" => &mut abi,
                    "mod_path" => &mut mod_path,
                    _ => return Err(unknown_parameter(key)),
                };
                *slot = Some(string_literal_contents(value).ok_or_else(|| {
                    format!("Parameter `{key}` expects a string literal value, e.g. {key} = \"...\"")
                })?);
            }
        }
    }

    let abi = abi.unwrap_or_else(|| "rust".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((abi, not_def_impl, mod_path))
}
83
84/// Defines an extern trait that can be called across FFI boundaries.
85///
86/// This macro converts a regular Rust trait into a trait that can be called through FFI.
87/// It generates:
88/// 1. The original trait definition
89/// 2. A module containing wrapper functions that call external implementations
90/// 3. Optionally, a helper macro `impl_trait!` for implementing the trait (unless `not_def_impl` is specified)
91/// 4. A checker function to ensure the trait is properly implemented
92///
93/// # Arguments
94/// - `abi`: Optional parameter specifying ABI type ("c" or "rust"), defaults to "rust"
95/// - `not_def_impl`: Optional parameter to skip generating the `impl_trait!` macro
96///
97/// # Example
98/// ```rust
99/// #[def_extern_trait(abi = "c")]
100/// trait Calculator {
101///     fn add(&self, a: i32, b: i32) -> i32;
102///     fn multiply(&self, a: i32, b: i32) -> i32;
103/// }
104///
105/// // Skip generating impl_trait! macro
106/// #[def_extern_trait(abi = "c", not_def_impl)]
107/// trait Calculator2 {
108///     fn add(&self, a: i32, b: i32) -> i32;
109/// }
110/// ```
111///
112/// This will generate a `calculator` module containing functions that can call external implementations.
113#[proc_macro_attribute]
114pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
115    let (abi, not_def_impl, _mod_path) = match parse_def_extern_trait_args(args) {
116        Ok((abi, not_def_impl, mod_path)) => (abi, not_def_impl, mod_path),
117        Err(error_msg) => {
118            bail!(Span::call_site(), error_msg);
119        }
120    };
121
122    let input = parse_macro_input!(input as ItemTrait);
123    let vis = input.vis.clone();
124    let mod_name = format_ident!(
125        "{}",
126        input.ident.to_string().to_case(convert_case::Case::Snake)
127    );
128    let crate_name_str = get_crate_name();
129
130    let mut fn_list = vec![];
131    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
132    let mut crate_path_tokens = quote! { #crate_name };
133    if let Some(mod_path) = _mod_path {
134        // 解析 mod_path 并生成路径tokens
135        let path_segments: Vec<&str> = mod_path.split("::").collect();
136        let path_idents: Vec<proc_macro2::Ident> = path_segments
137            .iter()
138            .map(|segment| format_ident!("{}", segment))
139            .collect();
140        crate_path_tokens = quote! { #crate_name::#(#path_idents)::* };
141    }
142
143    let crate_name_version = format!("{}_{}", crate_name_str, prefix_version());
144
145    for item in &input.items {
146        if let syn::TraitItem::Fn(func) = item {
147            let fn_name = func.sig.ident.clone();
148            let extern_fn_name = extern_fn_name(&crate_name_version, &fn_name);
149
150            let attrs = &func.attrs;
151            let inputs = &func.sig.inputs;
152            let output = &func.sig.output;
153            let generics = &func.sig.generics;
154            let unsafety = &func.sig.unsafety;
155
156            let mut param_names = vec![];
157            let mut param_types = vec![];
158
159            for input in inputs {
160                if let syn::FnArg::Typed(pat_type) = input {
161                    param_names.push(&pat_type.pat);
162                    param_types.push(&pat_type.ty);
163                }
164            }
165
166            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
167
168            fn_list.push(quote! {
169                #(#attrs)*
170                pub #unsafety fn #fn_name #generics (#inputs) #output {
171                    unsafe extern #extern_abi {
172                        fn #extern_fn_name #generics (#inputs) #output;
173                    }
174                    unsafe{ #extern_fn_name(#(#param_names),*) }
175                }
176            });
177        } else {
178            bail!(
179                item.span(),
180                "Only function items are allowed in extern traits"
181            );
182        }
183    }
184
185    let _warn_fn_name = format_ident!(
186        "Trait_{}_in_crate_{}_{}_need_impl",
187        input.ident,
188        crate_name_str.replace("-", "_"),
189        prefix_version()
190    );
191
192    let generated_macro = if not_def_impl {
193        quote! {}
194    } else {
195        quote! {
196            /// Helper macro to implement the extern trait for a type.
197            pub use trait_ffi::impl_extern_trait;
198
199            /// Implement the extern trait for a type.
200            #[macro_export]
201            macro_rules! impl_trait {
202                (impl $trait:ident for $type:ty { $($body:tt)* }) => {
203                    #[#crate_path_tokens::impl_extern_trait(name = #crate_name_version, abi = #abi)]
204                    impl $trait for $type {
205                        $($body)*
206                    }
207
208                    // #[allow(snake_case)]
209                    // #[unsafe(no_mangle)]
210                    // extern "C" fn #warn_fn_name() { }
211                };
212            }
213        }
214    };
215
216    quote! {
217        #input
218
219        /// Module generated by `trait-ffi`.
220        #vis mod #mod_name {
221            use super::*;
222            /// `trait-ffi` generated.
223            // pub fn ____checker_do_not_use(){
224            //     unsafe extern "C" {
225            //         fn #warn_fn_name();
226            //     }
227            //     unsafe { #warn_fn_name() };
228            // }
229            #(#fn_list)*
230        }
231
232        #generated_macro
233    }
234    .into()
235}
236
/// Parses the arguments of `#[impl_extern_trait(...)]`.
///
/// Supported comma-separated arguments:
/// - `name = "..."` (required): the symbol prefix shared with the defining crate.
/// - `abi = "c"` or `abi = "rust"`. Defaults to `"c"`.
///
/// Returns `(name, abi)`.
///
/// # Errors
/// Returns a message suitable for `compile_error!` in four cases:
/// - no arguments are given, or `name` is missing;
/// - an argument is unknown;
/// - a value is not a string literal;
/// - `abi` is not `"c"` or `"rust"`.
///
/// Malformed arguments used to be ignored silently; for example, `abi = rust` (unquoted)
/// quietly became the C ABI.
fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
    // Text between the first and last `"` of a rendered string literal, if any.
    fn string_literal_contents(value: &str) -> Option<String> {
        let start = value.find('"')?;
        let end = value.rfind('"')?;
        (start < end).then(|| value[start + 1..end].to_string())
    }

    if args.is_empty() {
        return Err(
            "Missing parameters. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
                .to_string(),
        );
    }

    let args_str = args.to_string();
    let mut name: Option<String> = None;
    let mut abi: Option<String> = None;

    // `key = "value"` pairs separated by commas. Empty parts (trailing comma) are skipped.
    for part in args_str.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        let Some((key, value)) = part.split_once('=') else {
            return Err(format!(
                "Unknown parameter `{part}`. Supported parameters: name=\"...\", abi=\"c\" | \"rust\""
            ));
        };
        let key = key.trim();
        let slot = match key {
            "name" => &mut name,
            "abi" => &mut abi,
            _ => {
                return Err(format!(
                    "Unknown parameter `{key}`. Supported parameters: name=\"...\", abi=\"c\" | \"rust\""
                ));
            }
        };
        *slot = Some(string_literal_contents(value).ok_or_else(|| {
            format!("Parameter `{key}` expects a string literal value, e.g. {key}=\"...\"")
        })?);
    }

    let name = name.ok_or_else(|| {
        "Missing name parameter. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
            .to_string()
    })?;
    let abi = abi.unwrap_or_else(|| "c".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((name, abi))
}
281
282/// Implements an extern trait for a type and generates corresponding C function exports.
283///
284/// This macro takes a trait implementation and generates extern "C" functions that can be
285/// called from other languages. Each method in the trait implementation gets a corresponding
286/// extern function with a mangled name based on the crate name and version.
287///
288/// # Arguments
289/// - `name`: The name of the crate that defines the extern trait
290/// - `abi`: The ABI to use for the extern functions ("c" or "rust"), defaults to "c"
291///
292/// # Example
293/// ```rust
294/// struct Calculator;
295///
296/// #[impl_extern_trait(name = "calculator_crate", abi = "c")]
297/// impl MyTrait for Calculator {
298///     fn add(&self, a: i32, b: i32) -> i32 {
299///         a + b
300///     }
301/// }
302/// ```
303///
304/// This will generate extern "C" functions that can be called from other languages.
305#[proc_macro_attribute]
306pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
307    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
308        Ok((name, abi)) => (name, abi),
309        Err(error_msg) => {
310            bail!(Span::call_site(), error_msg);
311        }
312    };
313    let input = parse_macro_input!(input as ItemImpl);
314    let mut extern_fn_list = vec![];
315
316    let struct_name = input.self_ty.clone();
317    let trait_name = input.clone().trait_.unwrap().1;
318
319    for item in &input.items {
320        if let syn::ImplItem::Fn(func) = item {
321            let fn_name_raw = &func.sig.ident;
322            let fn_name = extern_fn_name(&crate_name_str, fn_name_raw);
323
324            let inputs = &func.sig.inputs;
325            let output = &func.sig.output;
326            let generics = &func.sig.generics;
327            let unsafety = &func.sig.unsafety;
328            // preserve attributes from the original impl method (e.g. #[cfg], docs)
329            let attrs = &func.attrs;
330
331            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
332
333            let mut param_names = vec![];
334            let mut param_types = vec![];
335
336            for input in inputs {
337                if let syn::FnArg::Typed(pat_type) = input {
338                    param_names.push(&pat_type.pat);
339                    param_types.push(&pat_type.ty);
340                }
341            }
342            let mut body = quote! {
343                <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
344            };
345
346            if unsafety.is_some() {
347                body = quote! { unsafe { #body } };
348            }
349            extern_fn_list.push(quote! {
350                #(#attrs)*
351                /// `trait-ffi` generated extern function.
352                #[unsafe(no_mangle)]
353                pub #unsafety extern #extern_abi fn #fn_name #generics (#inputs) #output {
354                    #body
355                }
356            });
357        }
358    }
359
360    quote! {
361        #input
362        #(#extern_fn_list)*
363    }
364    .into()
365}