trait_ffi/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote};
7use syn::{Ident, ItemImpl, ItemTrait, parse_macro_input, spanned::Spanned};
8
/// Early-returns a compile error from the enclosing proc-macro function.
///
/// `$span` selects where the diagnostic points; `$message` is any `Display` value.
/// The expansion is a `return` statement whose value is converted with `.into()`,
/// so it must be used inside a function returning `proc_macro::TokenStream`.
macro_rules! bail {
    ($span:expr, $message:expr) => {
        return syn::Error::new($span, $message).to_compile_error().into();
    };
}
14
/// Returns the value of `CARGO_PKG_NAME` read at macro-expansion time.
///
/// Cargo sets this for the crate being compiled, i.e. the crate invoking the macro.
/// Falls back to `"unknown"` when the variable is missing or not valid Unicode.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    }
}
18
/// Returns the value of `CARGO_PKG_VERSION` read at macro-expansion time.
///
/// Falls back to `"0.1.0"` when the variable is missing or not valid Unicode.
fn get_crate_version() -> String {
    std::env::var("CARGO_PKG_VERSION")
        .ok()
        .unwrap_or_else(|| String::from("0.1.0"))
}
22
23fn prefix_version() -> String {
24    let version = lenient_semver::parse(&get_crate_version()).unwrap();
25    let major = version.major;
26    let minor = version.minor;
27    if major == 0 {
28        format!("0_{minor}")
29    } else {
30        major.to_string()
31    }
32}
33
34fn extern_fn_name(crate_name: &str, fn_name: &Ident) -> Ident {
35    let crate_name = crate_name.to_lowercase().replace("-", "_");
36    // let version = prefix_version();
37
38    format_ident!("__{crate_name}_{fn_name}")
39}
40
/// Parses the arguments of `#[def_extern_trait(...)]` into `(abi, not_def_impl, mod_path)`.
///
/// See [`parse_def_extern_trait_args_str`] for the accepted syntax and errors.
fn parse_def_extern_trait_args(
    args: TokenStream,
) -> Result<(String, bool, Option<String>), String> {
    parse_def_extern_trait_args_str(&args.to_string())
}

/// String-level parser behind [`parse_def_extern_trait_args`].
///
/// It is split out so it can be unit-tested without a live proc-macro context.
///
/// Accepts a comma-separated list, in any order, of:
/// - `abi = "c"` or `abi = "rust"` (default `"rust"`),
/// - `not_def_impl` (do not generate the `impl_trait!` helper macro),
/// - `mod_path = "a::b"` (module path inside the defining crate).
///
/// Empty input yields `("rust", false, None)`. If a key is repeated, the last value wins.
///
/// # Errors
/// Returns a message when:
/// - the `abi` value is unsupported,
/// - a `key = value` pair has no string-literal value, or
/// - an argument is not one of the above.
///
/// The original parser silently ignored typos such as `not_def_imp`.
fn parse_def_extern_trait_args_str(
    args: &str,
) -> Result<(String, bool, Option<String>), String> {
    const USAGE: &str = "Supported arguments: abi = \"c\" | \"rust\", not_def_impl, mod_path = \"path::to::module\"";

    /// Returns the text between the first and the last `"` of `raw`, if both exist.
    fn quoted(raw: &str) -> Option<String> {
        let start = raw.find('"')?;
        let end = raw.rfind('"')?;
        (start < end).then(|| raw[start + 1..end].to_string())
    }

    let mut abi = None;
    let mut not_def_impl = false;
    let mut mod_path = None;

    for part in args.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        if part == "not_def_impl" {
            not_def_impl = true;
            continue;
        }

        let Some((key, value)) = part.split_once('=') else {
            return Err(format!("Unknown argument `{part}`. {USAGE}"));
        };
        // Match the key exactly, so that e.g. `abix = "c"` is rejected rather than read as `abi`.
        let key = key.trim();
        let slot = match key {
            "abi" => &mut abi,
            "mod_path" => &mut mod_path,
            _ => return Err(format!("Unknown argument `{key}`. {USAGE}")),
        };
        let Some(value) = quoted(value) else {
            return Err(format!(
                "Argument `{key}` expects a string literal, e.g. {key} = \"...\""
            ));
        };
        *slot = Some(value);
    }

    let abi = abi.unwrap_or_else(|| "rust".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((abi, not_def_impl, mod_path))
}
83
84/// Defines an extern trait that can be called across FFI boundaries.
85///
86/// This macro converts a regular Rust trait into a trait that can be called through FFI.
87/// It generates:
88/// 1. The original trait definition
89/// 2. A module containing wrapper functions that call external implementations
90/// 3. Optionally, a helper macro `impl_trait!` for implementing the trait (unless `not_def_impl` is specified)
91/// 4. A checker function to ensure the trait is properly implemented
92///
93/// # Arguments
94/// - `abi`: Optional parameter specifying ABI type ("c" or "rust"), defaults to "rust"
95/// - `not_def_impl`: Optional parameter to skip generating the `impl_trait!` macro
96///
97/// # Example
98/// ```rust
99/// #[def_extern_trait(abi = "c")]
100/// trait Calculator {
101///     fn add(&self, a: i32, b: i32) -> i32;
102///     fn multiply(&self, a: i32, b: i32) -> i32;
103/// }
104///
105/// // Skip generating impl_trait! macro
106/// #[def_extern_trait(abi = "c", not_def_impl)]
107/// trait Calculator2 {
108///     fn add(&self, a: i32, b: i32) -> i32;
109/// }
110/// ```
111///
112/// This will generate a `calculator` module containing functions that can call external implementations.
113#[proc_macro_attribute]
114pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
115    let (abi, not_def_impl, _mod_path) = match parse_def_extern_trait_args(args) {
116        Ok((abi, not_def_impl, mod_path)) => (abi, not_def_impl, mod_path),
117        Err(error_msg) => {
118            bail!(Span::call_site(), error_msg);
119        }
120    };
121
122    let input = parse_macro_input!(input as ItemTrait);
123    let vis = input.vis.clone();
124    let mod_name = format_ident!(
125        "{}",
126        input.ident.to_string().to_case(convert_case::Case::Snake)
127    );
128    let crate_name_str = get_crate_name();
129
130    let mut fn_list = vec![];
131    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
132    let mut crate_path_tokens = quote! { #crate_name };
133    if let Some(mod_path) = _mod_path {
134        // 解析 mod_path 并生成路径tokens
135        let path_segments: Vec<&str> = mod_path.split("::").collect();
136        let path_idents: Vec<proc_macro2::Ident> = path_segments
137            .iter()
138            .map(|segment| format_ident!("{}", segment))
139            .collect();
140        crate_path_tokens = quote! { #crate_name::#(#path_idents)::* };
141    }
142
143    let crate_name_version = format!("{}_{}", crate_name_str, prefix_version());
144
145    for item in &input.items {
146        if let syn::TraitItem::Fn(func) = item {
147            let fn_name = func.sig.ident.clone();
148            let extern_fn_name = extern_fn_name(&crate_name_version, &fn_name);
149
150            let attrs = &func.attrs;
151            let inputs = &func.sig.inputs;
152            let output = &func.sig.output;
153            let generics = &func.sig.generics;
154            let unsafety = &func.sig.unsafety;
155
156            let mut param_names = vec![];
157            let mut param_types = vec![];
158
159            for input in inputs {
160                if let syn::FnArg::Typed(pat_type) = input {
161                    param_names.push(&pat_type.pat);
162                    param_types.push(&pat_type.ty);
163                }
164            }
165
166            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
167
168            fn_list.push(quote! {
169                #(#attrs)*
170                pub #unsafety fn #fn_name #generics (#inputs) #output {
171                    unsafe extern #extern_abi {
172                        fn #extern_fn_name #generics (#inputs) #output;
173                    }
174                    unsafe{ #extern_fn_name(#(#param_names),*) }
175                }
176            });
177        } else {
178            bail!(
179                item.span(),
180                "Only function items are allowed in extern traits"
181            );
182        }
183    }
184
185    let _warn_fn_name = format_ident!(
186        "Trait_{}_in_crate_{}_{}_need_impl",
187        input.ident,
188        crate_name_str.replace("-", "_"),
189        prefix_version()
190    );
191
192    let generated_macro = if not_def_impl {
193        quote! {}
194    } else {
195        quote! {
196            /// Helper macro to implement the extern trait for a type.
197            pub use trait_ffi::impl_extern_trait;
198
199            #[macro_export]
200            macro_rules! impl_trait {
201                (impl $trait:ident for $type:ty { $($body:tt)* }) => {
202                    #[#crate_path_tokens::impl_extern_trait(name = #crate_name_version, abi = #abi)]
203                    impl $trait for $type {
204                        $($body)*
205                    }
206
207                    // #[allow(snake_case)]
208                    // #[unsafe(no_mangle)]
209                    // extern "C" fn #warn_fn_name() { }
210                };
211            }
212        }
213    };
214
215    quote! {
216        #input
217
218        /// Module generated by `trait-ffi`.
219        #vis mod #mod_name {
220            use super::*;
221            /// `trait-ffi` generated.
222            // pub fn ____checker_do_not_use(){
223            //     unsafe extern "C" {
224            //         fn #warn_fn_name();
225            //     }
226            //     unsafe { #warn_fn_name() };
227            // }
228            #(#fn_list)*
229        }
230
231        #generated_macro
232    }
233    .into()
234}
235
/// Parses the arguments of `#[impl_extern_trait(...)]` into `(name, abi)`.
///
/// See [`parse_extern_trait_args_str`] for the accepted syntax and errors.
fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
    parse_extern_trait_args_str(&args.to_string())
}

/// String-level parser behind [`parse_extern_trait_args`].
///
/// It is split out so it can be unit-tested without a live proc-macro context.
///
/// Accepts a comma-separated list, in any order, of:
/// - `name = "..."` (required),
/// - `abi = "c"` or `abi = "rust"` (default `"c"`).
///
/// If a key is repeated, the last value wins.
///
/// # Errors
/// Returns a message when:
/// - the input is empty,
/// - `name` is missing,
/// - `abi` is unsupported,
/// - a `key = value` pair has no string-literal value, or
/// - an argument is not one of the above.
///
/// The original parser silently ignored such typos.
fn parse_extern_trait_args_str(args: &str) -> Result<(String, String), String> {
    const USAGE: &str = "Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]";

    /// Returns the text between the first and the last `"` of `raw`, if both exist.
    fn quoted(raw: &str) -> Option<String> {
        let start = raw.find('"')?;
        let end = raw.rfind('"')?;
        (start < end).then(|| raw[start + 1..end].to_string())
    }

    if args.trim().is_empty() {
        return Err(format!("Missing parameters. {USAGE}"));
    }

    let mut name = None;
    let mut abi = None;

    for part in args.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        let Some((key, value)) = part.split_once('=') else {
            return Err(format!("Unknown argument `{part}`. {USAGE}"));
        };
        // Match the key exactly, so that e.g. `names = "x"` is rejected rather than read as `name`.
        let key = key.trim();
        let slot = match key {
            "name" => &mut name,
            "abi" => &mut abi,
            _ => return Err(format!("Unknown argument `{key}`. {USAGE}")),
        };
        let Some(value) = quoted(value) else {
            return Err(format!(
                "Argument `{key}` expects a string literal, e.g. {key} = \"...\""
            ));
        };
        *slot = Some(value);
    }

    let name = name.ok_or_else(|| format!("Missing name parameter. {USAGE}"))?;
    let abi = abi.unwrap_or_else(|| "c".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((name, abi))
}
280
281/// Implements an extern trait for a type and generates corresponding C function exports.
282///
283/// This macro takes a trait implementation and generates extern "C" functions that can be
284/// called from other languages. Each method in the trait implementation gets a corresponding
285/// extern function with a mangled name based on the crate name and version.
286///
287/// # Arguments
288/// - `name`: The name of the crate that defines the extern trait
289/// - `abi`: The ABI to use for the extern functions ("c" or "rust"), defaults to "c"
290///
291/// # Example
292/// ```rust
293/// struct Calculator;
294///
295/// #[impl_extern_trait(name = "calculator_crate", abi = "c")]
296/// impl MyTrait for Calculator {
297///     fn add(&self, a: i32, b: i32) -> i32 {
298///         a + b
299///     }
300/// }
301/// ```
302///
303/// This will generate extern "C" functions that can be called from other languages.
304#[proc_macro_attribute]
305pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
306    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
307        Ok((name, abi)) => (name, abi),
308        Err(error_msg) => {
309            bail!(Span::call_site(), error_msg);
310        }
311    };
312    let input = parse_macro_input!(input as ItemImpl);
313    let mut extern_fn_list = vec![];
314
315    let struct_name = input.self_ty.clone();
316    let trait_name = input.clone().trait_.unwrap().1;
317
318    for item in &input.items {
319        if let syn::ImplItem::Fn(func) = item {
320            let fn_name_raw = &func.sig.ident;
321            let fn_name = extern_fn_name(&crate_name_str, fn_name_raw);
322
323            let inputs = &func.sig.inputs;
324            let output = &func.sig.output;
325            let generics = &func.sig.generics;
326            let unsafety = &func.sig.unsafety;
327            // preserve attributes from the original impl method (e.g. #[cfg], docs)
328            let attrs = &func.attrs;
329
330            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
331
332            let mut param_names = vec![];
333            let mut param_types = vec![];
334
335            for input in inputs {
336                if let syn::FnArg::Typed(pat_type) = input {
337                    param_names.push(&pat_type.pat);
338                    param_types.push(&pat_type.ty);
339                }
340            }
341            let mut body = quote! {
342                <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
343            };
344
345            if unsafety.is_some() {
346                body = quote! { unsafe { #body } };
347            }
348            extern_fn_list.push(quote! {
349                #(#attrs)*
350                /// `trait-ffi` generated extern function.
351                #[unsafe(no_mangle)]
352                pub #unsafety extern #extern_abi fn #fn_name #generics (#inputs) #output {
353                    #body
354                }
355            });
356        }
357    }
358
359    quote! {
360        #input
361        #(#extern_fn_list)*
362    }
363    .into()
364}