trait_ffi/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote};
7use syn::{Ident, ItemImpl, ItemTrait, parse_macro_input, spanned::Spanned};
8
/// Returns early from a `TokenStream`-returning macro entry point with a compile error that
/// points at `$span` and carries `$message`.
macro_rules! bail {
    ($span:expr, $message:expr) => {
        return syn::Error::new($span, $message).to_compile_error().into();
    };
}
14
/// Name of the package being compiled, read from Cargo's `CARGO_PKG_NAME` environment variable.
///
/// Falls back to `"unknown"` when the variable is absent or is not valid Unicode.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    }
}
18
/// Version of the package being compiled, read from Cargo's `CARGO_PKG_VERSION` environment
/// variable.
///
/// Falls back to `"0.1.0"` when the variable is absent or is not valid Unicode.
fn get_crate_version() -> String {
    if let Ok(version) = std::env::var("CARGO_PKG_VERSION") {
        version
    } else {
        String::from("0.1.0")
    }
}
22
23fn prefix_version() -> String {
24    let version = lenient_semver::parse(&get_crate_version()).unwrap();
25    let major = version.major;
26    let minor = version.minor;
27    if major == 0 {
28        format!("0_{minor}")
29    } else {
30        major.to_string()
31    }
32}
33
34fn extern_fn_name(crate_name: &str, fn_name: &Ident) -> Ident {
35    let crate_name = crate_name.to_lowercase().replace("-", "_");
36    // let version = prefix_version();
37
38    format_ident!("__{crate_name}_{fn_name}")
39}
40
/// Parses the attribute arguments of `#[def_extern_trait(...)]`.
///
/// Accepted arguments (comma-separated, any order):
/// - `abi = "c"` or `abi = "rust"` (default `"rust"`)
/// - `not_def_impl`: flag that skips generating the `impl_trait!` helper macro
/// - `mod_path = "a::b"`: module path used by the generated `impl_trait!` macro
///
/// Returns `(abi, not_def_impl, mod_path)`.
///
/// # Errors
/// Returns a human-readable message for an unknown parameter, a parameter whose value is not a
/// string literal, or an unsupported `abi` value.
fn parse_def_extern_trait_args(
    args: TokenStream,
) -> Result<(String, bool, Option<String>), String> {
    if args.is_empty() {
        // Defaults: Rust ABI, generate `impl_trait!`, no custom module path.
        return Ok(("rust".to_string(), false, None));
    }
    parse_def_extern_trait_args_str(&args.to_string())
}

/// String-level parser behind `parse_def_extern_trait_args`.
///
/// `args` is the `TokenStream::to_string()` rendering of the attribute arguments, e.g.
/// `abi = "c", not_def_impl`. Arguments are split on `,`, so string values cannot themselves
/// contain commas. Empty pieces (e.g. from a trailing comma) are ignored.
fn parse_def_extern_trait_args_str(
    args: &str,
) -> Result<(String, bool, Option<String>), String> {
    let unsupported = |part: &str| {
        format!(
            "Unsupported parameter `{part}`. Supported parameters: abi = \"c\" | \"rust\", not_def_impl, mod_path = \"...\""
        )
    };

    let mut abi = None;
    let mut not_def_impl = false;
    let mut mod_path = None;

    for part in args.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        if part == "not_def_impl" {
            not_def_impl = true;
            continue;
        }
        let Some((key, value)) = part.split_once('=') else {
            return Err(unsupported(part));
        };
        let key = key.trim();
        // `Some(inner)` only when the value is a complete double-quoted string literal.
        let value = value
            .trim()
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .map(str::to_string);
        match (key, value) {
            ("abi", Some(value)) => abi = Some(value),
            ("mod_path", Some(value)) => mod_path = Some(value),
            ("abi" | "mod_path", None) => {
                return Err(format!(
                    "Parameter `{key}` expects a string literal value, e.g. {key} = \"...\""
                ));
            }
            _ => return Err(unsupported(part)),
        }
    }

    let abi = abi.unwrap_or_else(|| "rust".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((abi, not_def_impl, mod_path))
}
83
84/// Defines an extern trait that can be called across FFI boundaries.
85///
86/// This macro converts a regular Rust trait into a trait that can be called through FFI.
87/// It generates:
88/// 1. The original trait definition
89/// 2. A module containing wrapper functions that call external implementations
90/// 3. Optionally, a helper macro `impl_trait!` for implementing the trait (unless `not_def_impl` is specified)
91/// 4. A checker function to ensure the trait is properly implemented
92///
93/// # Arguments
94/// - `abi`: Optional parameter specifying ABI type ("c" or "rust"), defaults to "rust"
95/// - `not_def_impl`: Optional parameter to skip generating the `impl_trait!` macro
96///
97/// # Example
98/// ```rust
99/// #[def_extern_trait(abi = "c")]
100/// trait Calculator {
101///     fn add(&self, a: i32, b: i32) -> i32;
102///     fn multiply(&self, a: i32, b: i32) -> i32;
103/// }
104///
105/// // Skip generating impl_trait! macro
106/// #[def_extern_trait(abi = "c", not_def_impl)]
107/// trait Calculator2 {
108///     fn add(&self, a: i32, b: i32) -> i32;
109/// }
110/// ```
111///
112/// This will generate a `calculator` module containing functions that can call external implementations.
113#[proc_macro_attribute]
114pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
115    let (abi, not_def_impl, _mod_path) = match parse_def_extern_trait_args(args) {
116        Ok((abi, not_def_impl, mod_path)) => (abi, not_def_impl, mod_path),
117        Err(error_msg) => {
118            bail!(Span::call_site(), error_msg);
119        }
120    };
121
122    let input = parse_macro_input!(input as ItemTrait);
123    let vis = input.vis.clone();
124    let mod_name = format_ident!(
125        "{}",
126        input.ident.to_string().to_case(convert_case::Case::Snake)
127    );
128    let crate_name_str = get_crate_name();
129
130    let mut fn_list = vec![];
131    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
132    let mut crate_path_tokens = quote! { #crate_name };
133    if let Some(mod_path) = _mod_path {
134        // 解析 mod_path 并生成路径tokens
135        let path_segments: Vec<&str> = mod_path.split("::").collect();
136        let path_idents: Vec<proc_macro2::Ident> = path_segments
137            .iter()
138            .map(|segment| format_ident!("{}", segment))
139            .collect();
140        crate_path_tokens = quote! { #crate_name::#(#path_idents)::* };
141    }
142
143    let crate_name_version = format!("{}_{}", crate_name_str, prefix_version());
144
145    for item in &input.items {
146        if let syn::TraitItem::Fn(func) = item {
147            let fn_name = func.sig.ident.clone();
148            let extern_fn_name = extern_fn_name(&crate_name_version, &fn_name);
149
150            let attrs = &func.attrs;
151            let inputs = &func.sig.inputs;
152            let output = &func.sig.output;
153            let generics = &func.sig.generics;
154            let unsafety = &func.sig.unsafety;
155
156            let mut param_names = vec![];
157            let mut param_types = vec![];
158
159            for input in inputs {
160                if let syn::FnArg::Typed(pat_type) = input {
161                    param_names.push(&pat_type.pat);
162                    param_types.push(&pat_type.ty);
163                }
164            }
165
166            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
167
168            fn_list.push(quote! {
169                #(#attrs)*
170                pub #unsafety fn #fn_name #generics (#inputs) #output {
171                    unsafe extern #extern_abi {
172                        fn #extern_fn_name #generics (#inputs) #output;
173                    }
174                    unsafe{ #extern_fn_name(#(#param_names),*) }
175                }
176            });
177        } else {
178            bail!(
179                item.span(),
180                "Only function items are allowed in extern traits"
181            );
182        }
183    }
184
185    let _warn_fn_name = format_ident!(
186        "Trait_{}_in_crate_{}_{}_need_impl",
187        input.ident,
188        crate_name_str.replace("-", "_"),
189        prefix_version()
190    );
191
192    let generated_macro = if not_def_impl {
193        quote! {}
194    } else {
195        quote! {
196            pub use trait_ffi::impl_extern_trait;
197
198            #[macro_export]
199            macro_rules! impl_trait {
200                (impl $trait:ident for $type:ty { $($body:tt)* }) => {
201                    #[#crate_path_tokens::impl_extern_trait(name = #crate_name_version, abi = #abi)]
202                    impl $trait for $type {
203                        $($body)*
204                    }
205
206                    // #[allow(snake_case)]
207                    // #[unsafe(no_mangle)]
208                    // extern "C" fn #warn_fn_name() { }
209                };
210            }
211        }
212    };
213
214    quote! {
215        #input
216
217        #vis mod #mod_name {
218            use super::*;
219            /// `trait-ffi` generated.
220            // pub fn ____checker_do_not_use(){
221            //     unsafe extern "C" {
222            //         fn #warn_fn_name();
223            //     }
224            //     unsafe { #warn_fn_name() };
225            // }
226            #(#fn_list)*
227        }
228
229        #generated_macro
230    }
231    .into()
232}
233
/// Parses the attribute arguments of `#[impl_extern_trait(...)]`.
///
/// Accepted arguments (comma-separated, any order):
/// - `name = "..."` (required): symbol prefix passed to `extern_fn_name`
/// - `abi = "c"` or `abi = "rust"` (default `"c"`)
///
/// Returns `(name, abi)`.
///
/// # Errors
/// Returns a human-readable message when the arguments are empty, `name` is missing, a
/// parameter is unknown or its value is not a string literal, or `abi` is unsupported.
fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
    parse_extern_trait_args_str(&args.to_string())
}

/// String-level parser behind `parse_extern_trait_args`.
///
/// `args` is the `TokenStream::to_string()` rendering of the attribute arguments, e.g.
/// `name = "foo_0_1", abi = "c"`. Arguments are split on `,`, so string values cannot
/// themselves contain commas. Empty pieces (e.g. from a trailing comma) are ignored.
fn parse_extern_trait_args_str(args: &str) -> Result<(String, String), String> {
    const USAGE: &str = "Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]";

    if args.trim().is_empty() {
        return Err(format!("Missing parameters. {USAGE}"));
    }

    let mut name = None;
    let mut abi = None;

    for part in args.split(',').map(str::trim).filter(|part| !part.is_empty()) {
        let Some((key, value)) = part.split_once('=') else {
            return Err(format!("Unsupported parameter `{part}`. {USAGE}"));
        };
        let key = key.trim();
        // `Some(inner)` only when the value is a complete double-quoted string literal.
        let value = value
            .trim()
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .map(str::to_string);
        match (key, value) {
            ("name", Some(value)) => name = Some(value),
            ("abi", Some(value)) => abi = Some(value),
            ("name" | "abi", None) => {
                return Err(format!(
                    "Parameter `{key}` expects a string literal value. {USAGE}"
                ));
            }
            _ => return Err(format!("Unsupported parameter `{part}`. {USAGE}")),
        }
    }

    let name = name.ok_or_else(|| format!("Missing name parameter. {USAGE}"))?;
    let abi = abi.unwrap_or_else(|| "c".to_string());

    if abi != "c" && abi != "rust" {
        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
    }

    Ok((name, abi))
}
278
279/// Implements an extern trait for a type and generates corresponding C function exports.
280///
281/// This macro takes a trait implementation and generates extern "C" functions that can be
282/// called from other languages. Each method in the trait implementation gets a corresponding
283/// extern function with a mangled name based on the crate name and version.
284///
285/// # Arguments
286/// - `name`: The name of the crate that defines the extern trait
287/// - `abi`: The ABI to use for the extern functions ("c" or "rust"), defaults to "c"
288///
289/// # Example
290/// ```rust
291/// struct Calculator;
292///
293/// #[impl_extern_trait(name = "calculator_crate", abi = "c")]
294/// impl MyTrait for Calculator {
295///     fn add(&self, a: i32, b: i32) -> i32 {
296///         a + b
297///     }
298/// }
299/// ```
300///
301/// This will generate extern "C" functions that can be called from other languages.
302#[proc_macro_attribute]
303pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
304    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
305        Ok((name, abi)) => (name, abi),
306        Err(error_msg) => {
307            bail!(Span::call_site(), error_msg);
308        }
309    };
310    let input = parse_macro_input!(input as ItemImpl);
311    let mut extern_fn_list = vec![];
312
313    let struct_name = input.self_ty.clone();
314    let trait_name = input.clone().trait_.unwrap().1;
315
316    for item in &input.items {
317        if let syn::ImplItem::Fn(func) = item {
318            let fn_name_raw = &func.sig.ident;
319            let fn_name = extern_fn_name(&crate_name_str, fn_name_raw);
320
321            let inputs = &func.sig.inputs;
322            let output = &func.sig.output;
323            let generics = &func.sig.generics;
324            let unsafety = &func.sig.unsafety;
325            // preserve attributes from the original impl method (e.g. #[cfg], docs)
326            let attrs = &func.attrs;
327
328            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
329
330            let mut param_names = vec![];
331            let mut param_types = vec![];
332
333            for input in inputs {
334                if let syn::FnArg::Typed(pat_type) = input {
335                    param_names.push(&pat_type.pat);
336                    param_types.push(&pat_type.ty);
337                }
338            }
339            let mut body = quote! {
340                <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
341            };
342
343            if unsafety.is_some() {
344                body = quote! { unsafe { #body } };
345            }
346            extern_fn_list.push(quote! {
347                #(#attrs)*
348                /// `trait-ffi` generated extern function.
349                #[unsafe(no_mangle)]
350                pub #unsafety extern #extern_abi fn #fn_name #generics (#inputs) #output {
351                    #body
352                }
353            });
354        }
355    }
356
357    quote! {
358        #input
359        #(#extern_fn_list)*
360    }
361    .into()
362}