rust2go_common/
g2r.rs

1// Copyright 2024 ihciah. All Rights Reserved.
2
3use std::collections::HashMap;
4
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use syn::{Error, FnArg, Ident, ItemTrait, Meta, Pat, Result, ReturnType, TraitItem, Type};
8
9use crate::common::{Param, ParamType};
10
11pub struct G2RTraitRepr {
12    name: Ident,
13    fns: Vec<G2RFnRepr>,
14}
15
16pub struct G2RFnRepr {
17    name: Ident,
18    params: Vec<Param>,
19    ret: Option<ParamType>,
20    cgo_call: bool,
21}
22
23impl TryFrom<&ItemTrait> for G2RTraitRepr {
24    type Error = Error;
25
26    fn try_from(trat: &ItemTrait) -> Result<Self> {
27        let trait_name = trat.ident.clone();
28        let mut fns = Vec::new();
29
30        for item in trat.items.iter() {
31            let TraitItem::Fn(fn_item) = item else {
32                sbail!("only fn items are supported");
33            };
34            let fn_name = fn_item.sig.ident.clone();
35            let mut params = Vec::new();
36            for param in fn_item.sig.inputs.iter() {
37                let FnArg::Typed(param) = param else {
38                    sbail!("only typed fn args are supported")
39                };
40                // param name
41                let Pat::Ident(param_name) = param.pat.as_ref() else {
42                    sbail!("only ident fn args are supported");
43                };
44                // param type
45                let param_type = ParamType::try_from(param.ty.as_ref())?;
46                params.push(Param {
47                    name: param_name.ident.clone(),
48                    ty: param_type,
49                });
50            }
51            if fn_item.sig.asyncness.is_some() {
52                sbail!("async is not supported yet when go call rust, manually spawn by your own!");
53            }
54            let param_type = match &fn_item.sig.output {
55                ReturnType::Default => None,
56                ReturnType::Type(_, t) => match t.as_ref() {
57                    Type::Path(_) => {
58                        let param_type = ParamType::try_from(t.as_ref())?;
59                        Some(param_type)
60                    }
61                    _ => sbail!("only path type returns are supported"),
62                },
63            };
64            let ret = param_type;
65            let cgo_call = fn_item
66                .attrs
67                .iter()
68                .any(|attr|
69                    matches!(&attr.meta, Meta::Path(p) if p.get_ident() == Some(&format_ident!("cgo_call")) || p.get_ident() == Some(&format_ident!("cgo")))
70                );
71            fns.push(G2RFnRepr {
72                name: fn_name,
73                params,
74                ret,
75                cgo_call,
76            });
77        }
78
79        Ok(G2RTraitRepr {
80            name: trait_name,
81            fns,
82        })
83    }
84}
85
/// Expands to `$content` when `$flag` is true and to `""` otherwise.
/// `$content` is only evaluated when the flag holds.
macro_rules! or_empty {
    ($flag: expr, $content: expr) => {
        match $flag {
            true => $content,
            false => "",
        }
    };
}
95
96impl G2RTraitRepr {
97    pub fn fns(&self) -> &[G2RFnRepr] {
98        &self.fns
99    }
100
101    pub fn has_ret(&self) -> bool {
102        self.fns.iter().any(|f| f.ret.is_some())
103    }
104
105    pub fn to_importc(&self) -> String {
106        let prefix = format!("const void c_{}_", self.name);
107        let decs = self
108            .fns
109            .iter()
110            .map(|f| match f.ffi_param_cnt() {
111                0 => format!("{prefix}{}();\n", f.name),
112                1 => format!("{prefix}{}(const void*);\n", f.name),
113                _ => format!("{prefix}{}(const void*, const void*);\n", f.name),
114            })
115            .collect::<Vec<String>>();
116        decs.join("")
117    }
118
119    pub fn to_go(&self, levels: &HashMap<Ident, u8>) -> String {
120        let trait_name = &self.name;
121        let struct_name = format!("{trait_name}Impl");
122        let mut out = format!("type {struct_name} struct{{}}\n");
123
124        for f in &self.fns {
125            let call_type = if f.cgo_call { "cgocall" } else { "asmcall" };
126            let ffi_param_cnt = f.ffi_param_cnt();
127            let f_name = &f.name;
128
129            let params = f
130                .params
131                .iter()
132                .map(|p| format!("{} *{}", p.name, p.ty.to_go()))
133                .collect::<Vec<_>>()
134                .join(",");
135            let ret = f.ret.as_ref().map_or(String::new(), |ret| ret.to_go());
136            let init_slot = or_empty!(f.ret.is_some(), "_internal_slot := [2]unsafe.Pointer{}\n");
137            let mut init_params = String::new();
138            if !f.params.is_empty() {
139                init_params = format!(
140                    "_internal_params := [{}]unsafe.Pointer{{}}\n",
141                    f.params.len()
142                );
143            }
144
145            // write function header
146            out.push_str(&format!(
147                "func ({struct_name}) {f_name}({params}) {ret} {{
148                    {init_slot}{init_params}"
149            ));
150
151            // convert params
152            for (i, p) in f.params.iter().enumerate() {
153                // user_ref, user_buffer := cvt_ref(cntDemoUser, refDemoUser)(user)
154                // _internal_params[0] = unsafe.Pointer(&user_ref)
155                let cnt = p.ty.go_to_c_field_counter(levels).0;
156                let ref_ = p.ty.go_to_c_field_converter(levels).0;
157                out.push_str(&format!(
158                    "{pname}_ref, {pname}_buffer := cvt_ref({cnt}, {ref_})({pname})
159                    _internal_params[{i}] = unsafe.Pointer(&{pname}_ref)
160                    ",
161                    pname = p.name,
162                ));
163            }
164
165            // call
166            let mut call_params = String::new();
167            // unsafe.Pointer(&_internal_slot), unsafe.Pointer(&_internal_params)
168            if f.ret.is_some() {
169                call_params.push_str(", unsafe.Pointer(&_internal_slot)");
170            }
171            if !f.params.is_empty() {
172                call_params.push_str(", unsafe.Pointer(&_internal_params)");
173            }
174            out.push_str(&format!(
175                "{call_type}.CallFuncG0P{ffi_param_cnt}(unsafe.Pointer(C.c_{trait_name}_{f_name}){call_params})\n"
176            ));
177
178            // keepalive
179            if f.ret.is_some() {
180                out.push_str("runtime.KeepAlive(_internal_slot)\n");
181            }
182            if !f.params.is_empty() {
183                out.push_str("runtime.KeepAlive(_internal_params)\n");
184            }
185            for p in f.params.iter() {
186                out.push_str(&format!("runtime.KeepAlive({}_buffer)\n", p.name));
187            }
188
189            if let Some(r) = &f.ret {
190                // val := ownString(*(*C.StringRef)(_internal_slot[0]))
191                // asmcall.CallFuncG0P1(unsafe.Pointer(C.c_rust2go_internal_drop), unsafe.Pointer(_internal_slot[1]))
192                // return val
193                let cvt = r.c_to_go_field_converter_owned();
194                let cty = r.to_c(false);
195                out.push_str(&format!("val := {cvt}(*(*C.{cty})(_internal_slot[0]))
196                {call_type}.CallFuncG0P1(unsafe.Pointer(C.c_rust2go_internal_drop), unsafe.Pointer(_internal_slot[1]))
197                return val
198                "));
199            }
200
201            out.push_str("}\n");
202        }
203
204        out
205    }
206
207    // Generate rust impl.
208    pub fn generate_rs(&self) -> Result<TokenStream> {
209        let trait_name = &self.name;
210        let mut fn_entries = Vec::with_capacity(self.fns.len());
211        for f in self.fns.iter() {
212            let f_name = &f.name;
213            let cf_name = format_ident!("c_{}_{}", &self.name, &f.name);
214            let slot_expr = f
215                .ret
216                .as_ref()
217                .map(|_| quote! {_internal_slot: *mut [*const (); 2],});
218            let mut params_expr = None;
219            if !f.params.is_empty() {
220                params_expr = Some(quote! {_internal_params: *const *const ()});
221            }
222            let mut params = Vec::new();
223            let mut param_names = Vec::new();
224            for (i, p) in f.params.iter().enumerate() {
225                let p_name = &p.name;
226                let i = i as isize;
227                params.push(quote! {
228                    let #p_name = _internal_params.offset(#i).read() as *const _;
229                    let #p_name = ::rust2go::FromRef::from_ref(unsafe { &*#p_name });
230                });
231                param_names.push(p.name.clone());
232            }
233
234            let bottom = if f.ret.is_some() {
235                quote! {
236                    let _internal_out = <Self as #trait_name>::#f_name(#(#param_names),*);
237                    let (_internal_buf, _internal_out_ref) = ::rust2go::ToRef::calc_ref(&_internal_out);
238
239                    let _internal_boxed_storage = ::std::boxed::Box::new((_internal_out, _internal_out_ref, _internal_buf));
240                    let ret_ptr = &_internal_boxed_storage.as_ref().1 as *const _ as *const ();
241                    let drop_ptr = ::std::boxed::Box::leak(_internal_boxed_storage as ::std::boxed::Box<dyn ::std::any::Any>) as *mut dyn ::std::any::Any as *mut ();
242
243                    *_internal_slot = [ret_ptr, drop_ptr];
244                }
245            } else {
246                quote! {
247                    <Self as #trait_name>::#f_name(#(#param_names),*);
248                }
249            };
250
251            fn_entries.push(quote! {
252                #[no_mangle]
253                unsafe extern "C" fn #cf_name(#slot_expr #params_expr) {
254                    #(#params)*
255                    #bottom
256                }
257            });
258        }
259
260        let impl_struct_name = format_ident!("{}Impl", trait_name);
261
262        Ok(quote! {
263            pub struct #impl_struct_name;
264            impl #impl_struct_name {
265                #(#fn_entries)*
266            }
267        })
268    }
269}
270
271impl G2RFnRepr {
272    fn ffi_param_cnt(&self) -> u8 {
273        [self.params.is_empty(), self.ret.is_none()]
274            .into_iter()
275            .filter(|x| !*x)
276            .count() as u8
277    }
278
279    pub const fn cgo_call(&self) -> bool {
280        self.cgo_call
281    }
282}