trait_ffi/
lib.rs

1use convert_case::Casing;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use syn::{Ident, ItemImpl, ItemTrait, parse_macro_input, spanned::Spanned};
6
/// Early-returns from a proc-macro entry point, replacing its output with a
/// `compile_error!` invocation.
///
/// `$span` must be a `proc_macro2::Span` (where the error is reported) and `$message`
/// any `Display` value; the resulting tokens are converted into the caller's return
/// type with `.into()`.
macro_rules! bail {
    ($span:expr, $message:expr) => {
        return syn::Error::new($span, $message).into_compile_error().into();
    };
}
12
/// Returns the value of `CARGO_PKG_NAME`, i.e. the package Cargo is currently compiling.
///
/// The variable is read at macro-expansion time (not baked in with `env!`), so inside a
/// proc macro it names the crate that *invokes* the macro. Falls back to `"unknown"` when
/// the variable is missing or not valid Unicode.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    }
}
/// Returns the value of `CARGO_PKG_VERSION` for the package Cargo is currently compiling.
///
/// Like `get_crate_name`, this is read at macro-expansion time, so it reflects the crate
/// invoking the macro. Falls back to `"0.1.0"` when the variable is missing or not valid
/// Unicode.
fn get_crate_version() -> String {
    let fallback = || "0.1.0".to_owned();
    std::env::var("CARGO_PKG_VERSION").ok().unwrap_or_else(fallback)
}
19
20fn prefix_version() -> String {
21    let version = lenient_semver::parse(&get_crate_version()).unwrap();
22    let major = version.major;
23    let minor = version.minor;
24    if major == 0 {
25        format!("0_{minor}")
26    } else {
27        major.to_string()
28    }
29}
30
31fn extern_fn_name(crate_name: &str, fn_name: &Ident) -> Ident {
32    let crate_name = crate_name.to_lowercase().replace("-", "_");
33    let version = prefix_version();
34
35    format_ident!("__{crate_name}_{version}_{fn_name}")
36}
37
38fn parse_def_extern_trait_args(args: TokenStream) -> Result<String, String> {
39    if args.is_empty() {
40        return Ok("rust".to_string()); // 默认使用 Rust ABI
41    }
42
43    let args_str = args.to_string();
44    let mut abi = None;
45
46    // 简单解析 abi="value" 形式
47    let parts: Vec<&str> = args_str.split(',').collect();
48
49    for part in parts {
50        let part = part.trim();
51        if part.starts_with("abi") {
52            if let Some(start) = part.find('"') {
53                if let Some(end) = part.rfind('"') {
54                    if start < end {
55                        abi = Some(part[start + 1..end].to_string());
56                    }
57                }
58            }
59        }
60    }
61
62    let abi = abi.unwrap_or_else(|| "rust".to_string());
63
64    if abi != "c" && abi != "rust" {
65        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
66    }
67
68    Ok(abi)
69}
70
71#[proc_macro_attribute]
72pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
73    let abi = match parse_def_extern_trait_args(args) {
74        Ok(abi) => abi,
75        Err(error_msg) => {
76            bail!(Span::call_site(), error_msg);
77        }
78    };
79
80    let input = parse_macro_input!(input as ItemTrait);
81    let vis = input.vis.clone();
82    let mod_name = format_ident!(
83        "{}",
84        input.ident.to_string().to_case(convert_case::Case::Snake)
85    );
86    let crate_name_str = get_crate_name();
87
88    let mut fn_list = vec![];
89
90    for item in &input.items {
91        if let syn::TraitItem::Fn(func) = item {
92            let fn_name = func.sig.ident.clone();
93            let extern_fn_name = extern_fn_name(&crate_name_str, &fn_name);
94
95            let attrs = &func.attrs;
96            let inputs = &func.sig.inputs;
97            let output = &func.sig.output;
98
99            // 生成参数名和类型
100            let mut param_names = vec![];
101            let mut param_types = vec![];
102
103            for input in inputs {
104                if let syn::FnArg::Typed(pat_type) = input {
105                    param_names.push(&pat_type.pat);
106                    param_types.push(&pat_type.ty);
107                }
108            }
109
110            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
111
112            fn_list.push(quote! {
113                #(#attrs)*
114                pub fn #fn_name(#inputs) #output {
115                    unsafe extern #extern_abi {
116                        fn #extern_fn_name(#inputs) #output;
117                    }
118                    unsafe{ #extern_fn_name(#(#param_names),*) }
119                }
120            });
121        } else {
122            bail!(
123                item.span(),
124                "Only function items are allowed in extern traits"
125            );
126        }
127    }
128
129    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
130
131    let warn_fn_name = format_ident!(
132        "Trait_{}_in_crate_{}_{}_need_impl",
133        input.ident,
134        crate_name_str.replace("-", "_"),
135        prefix_version()
136    );
137
138    let generated_macro = quote! {
139        #[macro_export]
140        macro_rules! impl_trait {
141            (impl $trait:ident for $type:ty { $($body:tt)* }) => {
142                #[#crate_name::impl_extern_trait(name = #crate_name_str, abi = #abi)]
143                impl $trait for $type {
144                    $($body)*
145                }
146
147                #[allow(snake_case)]
148                #[unsafe(no_mangle)]
149                extern "C" fn #warn_fn_name() { }
150            };
151        }
152    };
153
154    quote! {
155        pub use trait_ffi::impl_extern_trait;
156
157        #input
158
159        #vis mod #mod_name {
160            use super::*;
161            pub fn ____checker_do_not_use(){
162                unsafe extern "C" {
163                    fn #warn_fn_name();
164                }
165                unsafe { #warn_fn_name() };
166            }
167            #(#fn_list)*
168        }
169
170        #generated_macro
171    }
172    .into()
173}
174
175fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
176    if args.is_empty() {
177        return Err(
178            "Missing parameters. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
179                .to_string(),
180        );
181    }
182
183    let args_str = args.to_string();
184    let mut name = None;
185    let mut abi = None;
186
187    // 简单解析 name="value", abi="value" 形式
188    let parts: Vec<&str> = args_str.split(',').collect();
189
190    for part in parts {
191        let part = part.trim();
192        if part.starts_with("name") {
193            if let Some(start) = part.find('"') {
194                if let Some(end) = part.rfind('"') {
195                    if start < end {
196                        name = Some(part[start + 1..end].to_string());
197                    }
198                }
199            }
200        } else if part.starts_with("abi") {
201            if let Some(start) = part.find('"') {
202                if let Some(end) = part.rfind('"') {
203                    if start < end {
204                        abi = Some(part[start + 1..end].to_string());
205                    }
206                }
207            }
208        }
209    }
210
211    let name = name.ok_or_else(|| {
212        "Missing name parameter. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
213            .to_string()
214    })?;
215    let abi = abi.unwrap_or_else(|| "c".to_string());
216
217    if abi != "c" && abi != "rust" {
218        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
219    }
220
221    Ok((name, abi))
222}
223
224#[proc_macro_attribute]
225pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
226    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
227        Ok((name, abi)) => (name, abi),
228        Err(error_msg) => {
229            bail!(Span::call_site(), error_msg);
230        }
231    };
232    let input = parse_macro_input!(input as ItemImpl);
233    let mut extern_fn_list = vec![];
234
235    let struct_name = input.self_ty.clone();
236    let trait_name = input.clone().trait_.unwrap().1;
237
238    for item in &input.items {
239        if let syn::ImplItem::Fn(func) = item {
240            let fn_name_raw = &func.sig.ident;
241            let fn_name = extern_fn_name(&crate_name_str, fn_name_raw);
242
243            let inputs = &func.sig.inputs;
244            let output = &func.sig.output;
245
246            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
247
248            let mut param_names = vec![];
249            let mut param_types = vec![];
250
251            for input in inputs {
252                if let syn::FnArg::Typed(pat_type) = input {
253                    param_names.push(&pat_type.pat);
254                    param_types.push(&pat_type.ty);
255                }
256            }
257
258            extern_fn_list.push(quote! {
259                #[unsafe(no_mangle)]
260                pub extern #extern_abi fn #fn_name(#inputs) #output {
261                    <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
262                }
263            });
264        }
265    }
266
267    quote! {
268        #input
269        #(#extern_fn_list)*
270    }
271    .into()
272}