rust2go_common/
common.rs

1// Copyright 2024 ihciah. All Rights Reserved.
2
3use proc_macro2::{Span, TokenStream};
4use quote::{format_ident, quote, ToTokens};
5use std::collections::HashMap;
6use syn::{Error, File, Ident, Item, PathSegment, Result, Type};
7
8use crate::{g2r::G2RTraitRepr, r2g::R2GTraitRepr};
9
10pub struct RawRsFile {
11    file: File,
12}
13
14impl RawRsFile {
15    pub fn new<S: AsRef<str>>(src: S) -> Self {
16        let src = src.as_ref();
17        let syntax = syn::parse_file(src).expect("Unable to parse file");
18        RawRsFile { file: syntax }
19    }
20
21    pub fn go_internal_drop() -> &'static str {
22        r#"
23const void c_rust2go_internal_drop(void*);
24"#
25    }
26
27    pub fn go_shm_include() -> &'static str {
28        r#"
29typedef struct QueueMeta {
30    uintptr_t buffer_ptr;
31    uintptr_t buffer_len;
32    uintptr_t head_ptr;
33    uintptr_t tail_ptr;
34    uintptr_t working_ptr;
35    uintptr_t stuck_ptr;
36    int32_t working_fd;
37    int32_t unstuck_fd;
38    } QueueMeta;
39"#
40    }
41
42    pub fn go_shm_ring_init() -> &'static str {
43        r#"
44        func ringsInit(crr, crw C.QueueMeta, fns []func(unsafe.Pointer, *ants.MultiPool, func(interface{}, []byte, uint))) {
45            const MULTIPOOL_SIZE = 8
46            const SIZE_PER_POOL = -1
47
48            type Storage struct {
49                resp   interface{}
50                buffer []byte
51            }
52
53            type Payload struct {
54                Ptr          uint
55                UserData     uint
56                NextUserData uint
57                CallId       uint32
58                Flag         uint32
59            }
60
61            const CALL = 0b0101
62            const REPLY = 0b1110
63            const DROP = 0b1000
64
65            queueMetaCvt := func(cq C.QueueMeta) mem_ring.QueueMeta {
66                return mem_ring.QueueMeta{
67                    BufferPtr:  uintptr(cq.buffer_ptr),
68                    BufferLen:  uintptr(cq.buffer_len),
69                    HeadPtr:    uintptr(cq.head_ptr),
70                    TailPtr:    uintptr(cq.tail_ptr),
71                    WorkingPtr: uintptr(cq.working_ptr),
72                    StuckPtr:   uintptr(cq.stuck_ptr),
73                    WorkingFd:  int32(cq.working_fd),
74                    UnstuckFd:  int32(cq.unstuck_fd),
75                }
76            }
77
78            rr := queueMetaCvt(crr)
79            rw := queueMetaCvt(crw)
80
81            rrq := mem_ring.NewQueue[Payload](rr)
82            rwq := mem_ring.NewQueue[Payload](rw)
83
84            gr := rwq.Read()
85            gw := rrq.Write()
86
87            slab := mem_ring.NewMultiSlab[Storage]()
88            pool, _ := ants.NewMultiPool(MULTIPOOL_SIZE, SIZE_PER_POOL, ants.RoundRobin)
89
90            gr.RunHandler(func(p Payload) {
91                if p.Flag == CALL {
92                    post_func := func(resp interface{}, buffer []byte, offset uint) {
93                        if resp == nil {
94                            payload := Payload{
95                                Ptr:          0,
96                                UserData:     p.UserData,
97                                NextUserData: 0,
98                                CallId:       p.CallId,
99                                Flag:         DROP,
100                            }
101                            gw.Push(payload)
102                            return
103                        }
104
105                        // Use slab to hold reference of resp and buffer
106                        sid := slab.Push(Storage{
107                            resp,
108                            buffer,
109                        })
110                        payload := Payload{
111                            Ptr:          uint(uintptr(unsafe.Pointer(&buffer[offset]))),
112                            UserData:     p.UserData,
113                            NextUserData: sid,
114                            CallId:       p.CallId,
115                            Flag:         REPLY,
116                        }
117                        gw.Push(payload)
118                    }
119                    fns[p.CallId](unsafe.Pointer(uintptr(p.Ptr)), pool, post_func)
120                } else if p.Flag == DROP {
121                    // drop memory instantly
122                    slab.Pop(p.UserData)
123                }
124            })
125        }
126        "#
127    }
128
129    // The returned mapping is struct OriginalType -> RefType.
130    pub fn convert_structs_to_ref(&self) -> Result<(HashMap<Ident, Ident>, TokenStream)> {
131        let mut name_mapping = HashMap::new();
132
133        // Add these to generated code to make golang have C structs of string.
134        let mut out = quote! {
135            #[repr(C)]
136            pub struct StringRef {
137                pub ptr: *const u8,
138                pub len: usize,
139            }
140            #[repr(C)]
141            pub struct ListRef {
142                pub ptr: *const (),
143                pub len: usize,
144            }
145        };
146        name_mapping.insert(
147            Ident::new("String", Span::call_site()),
148            Ident::new("StringRef", Span::call_site()),
149        );
150        name_mapping.insert(
151            Ident::new("Vec", Span::call_site()),
152            Ident::new("ListRef", Span::call_site()),
153        );
154
155        for item in self.file.items.iter() {
156            match item {
157                // for example, convert
158                // pub struct DemoRequest {
159                //     pub name: String,
160                //     pub age: u8,
161                // }
162                // to
163                // #[repr(C)]
164                // pub struct DemoRequestRef {
165                //    pub name: StringRef,
166                //    pub age: u8,
167                // }
168                Item::Struct(s) => {
169                    let struct_name = s.ident.clone();
170                    let struct_name_ref = format_ident!("{}Ref", struct_name);
171                    name_mapping.insert(struct_name, struct_name_ref.clone());
172                    let mut field_names = Vec::with_capacity(s.fields.len());
173                    let mut field_types = Vec::with_capacity(s.fields.len());
174                    for field in s.fields.iter() {
175                        let field_name = field
176                            .clone()
177                            .ident
178                            .ok_or_else(|| serr!("only named fields are supported"))?;
179                        let field_type = ParamType::try_from(&field.ty)?;
180                        field_names.push(field_name);
181                        field_types.push(field_type.to_rust_ref(None));
182                    }
183                    out.extend(quote! {
184                        #[repr(C)]
185                        pub struct #struct_name_ref {
186                            #(pub #field_names: #field_types,)*
187                        }
188                    });
189                }
190                _ => continue,
191            }
192        }
193        Ok((name_mapping, out))
194    }
195
196    // go structs define and newStruct/refStruct function impl.
197    pub fn convert_structs_to_go(
198        &self,
199        levels: &HashMap<Ident, u8>,
200        go118: bool,
201    ) -> Result<String> {
202        const GO118CODE: &str = r#"
203        // An alternative impl of unsafe.String for go1.18
204        func unsafeString(ptr *byte, length int) string {
205            sliceHeader := &reflect.SliceHeader{
206                Data: uintptr(unsafe.Pointer(ptr)),
207                Len:  length,
208                Cap:  length,
209            }
210            return *(*string)(unsafe.Pointer(sliceHeader))
211        }
212
213        // An alternative impl of unsafe.StringData for go1.18
214        func unsafeStringData(s string) *byte {
215            return (*byte)(unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data))
216        }
217        func newString(s_ref C.StringRef) string {
218            return unsafeString((*byte)(unsafe.Pointer(s_ref.ptr)), int(s_ref.len))
219        }
220        func refString(s *string, _ *[]byte) C.StringRef {
221            return C.StringRef{
222                ptr: (*C.uint8_t)(unsafeStringData(*s)),
223                len: C.uintptr_t(len(*s)),
224            }
225        }
226        "#;
227
228        const GO121CODE: &str = r#"
229        func newString(s_ref C.StringRef) string {
230            return unsafe.String((*byte)(unsafe.Pointer(s_ref.ptr)), s_ref.len)
231        }
232        func refString(s *string, _ *[]byte) C.StringRef {
233            return C.StringRef{
234                ptr: (*C.uint8_t)(unsafe.StringData(*s)),
235                len: C.uintptr_t(len(*s)),
236            }
237        }
238        "#;
239
240        let mut out = if go118 {
241            GO118CODE.to_string()
242        } else {
243            GO121CODE.to_string()
244        } + r#"
245        func ownString(s_ref C.StringRef) string {
246            return string(unsafe.Slice((*byte)(unsafe.Pointer(s_ref.ptr)), int(s_ref.len)))
247        }
248        func cntString(_ *string, _ *uint) [0]C.StringRef { return [0]C.StringRef{} }
249        func new_list_mapper[T1, T2 any](f func(T1) T2) func(C.ListRef) []T2 {
250            return func(x C.ListRef) []T2 {
251                input := unsafe.Slice((*T1)(unsafe.Pointer(x.ptr)), x.len)
252                output := make([]T2, len(input))
253                for i, v := range input {
254                    output[i] = f(v)
255                }
256                return output
257            }
258        }
259        func new_list_mapper_primitive[T1, T2 any](_ func(T1) T2) func(C.ListRef) []T2 {
260            return func(x C.ListRef) []T2 {
261                return unsafe.Slice((*T2)(unsafe.Pointer(x.ptr)), x.len)
262            }
263        }
264        // only handle non-primitive type T
265        func cnt_list_mapper[T, R any](f func(s *T, cnt *uint)[0]R) func(s *[]T, cnt *uint) [0]C.ListRef {
266            return func(s *[]T, cnt *uint) [0]C.ListRef {
267                for _, v := range *s {
268                    f(&v, cnt)
269                }
270                *cnt += uint(len(*s)) * size_of[R]()
271                return [0]C.ListRef{}
272            }
273        }
274
275        // only handle primitive type T
276        func cnt_list_mapper_primitive[T, R any](_ func(s *T, cnt *uint)[0]R) func(s *[]T, cnt *uint) [0]C.ListRef {
277            return func(s *[]T, cnt *uint) [0]C.ListRef {return [0]C.ListRef{}}
278        }
279        // only handle non-primitive type T
280        func ref_list_mapper[T, R any](f func(s *T, buffer *[]byte) R) func(s *[]T, buffer *[]byte) C.ListRef {
281            return func(s *[]T, buffer *[]byte) C.ListRef {
282                if len(*buffer) == 0 {
283                    return C.ListRef{
284                        ptr: unsafe.Pointer(nil),
285                        len: C.uintptr_t(len(*s)),
286                    }
287                }
288                ret := C.ListRef{
289                    ptr: unsafe.Pointer(&(*buffer)[0]),
290                    len: C.uintptr_t(len(*s)),
291                }
292                children_bytes := int(size_of[R]()) * len(*s)
293                children := (*buffer)[:children_bytes]
294                *buffer = (*buffer)[children_bytes:]
295                for _, v := range *s {
296                    child := f(&v, buffer)
297                    len := unsafe.Sizeof(child)
298                    copy(children, unsafe.Slice((*byte)(unsafe.Pointer(&child)), len))
299                    children = children[len:]
300                }
301                return ret
302            }
303        }
304        // only handle primitive type T
305        func ref_list_mapper_primitive[T, R any](_ func(s *T, buffer *[]byte) R) func(s *[]T, buffer *[]byte) C.ListRef {
306            return func(s *[]T, buffer *[]byte) C.ListRef {
307                if len(*s) == 0 {
308                    return C.ListRef{
309                        ptr: unsafe.Pointer(nil),
310                        len: C.uintptr_t(0),
311                    }
312                }
313                return C.ListRef{
314                    ptr: unsafe.Pointer(&(*s)[0]),
315                    len: C.uintptr_t(len(*s)),
316                }
317            }
318        }
319        func size_of[T any]() uint {
320            var t T
321            return uint(unsafe.Sizeof(t))
322        }
323        func cvt_ref[R, CR any](cnt_f func(s *R, cnt *uint) [0]CR, ref_f func(p *R, buffer *[]byte) CR) func(p *R) (CR, []byte) {
324            return func(p *R) (CR, []byte) {
325                var cnt uint
326                cnt_f(p, &cnt)
327                buffer := make([]byte, cnt)
328                return ref_f(p, &buffer), buffer
329            }
330        }
331        func cvt_ref_cap[R, CR any](cnt_f func(s *R, cnt *uint) [0]CR, ref_f func(p *R, buffer *[]byte) CR, add_cap uint) func(p *R) (CR, []byte) {
332            return func(p *R) (CR, []byte) {
333                var cnt uint
334                cnt_f(p, &cnt)
335                buffer := make([]byte, cnt, cnt + add_cap)
336                return ref_f(p, &buffer), buffer
337            }
338        }
339
340        func newC_uint8_t(n C.uint8_t) uint8    { return uint8(n) }
341        func newC_uint16_t(n C.uint16_t) uint16 { return uint16(n) }
342        func newC_uint32_t(n C.uint32_t) uint32 { return uint32(n) }
343        func newC_uint64_t(n C.uint64_t) uint64 { return uint64(n) }
344        func newC_int8_t(n C.int8_t) int8       { return int8(n) }
345        func newC_int16_t(n C.int16_t) int16    { return int16(n) }
346        func newC_int32_t(n C.int32_t) int32    { return int32(n) }
347        func newC_int64_t(n C.int64_t) int64    { return int64(n) }
348        func newC_bool(n C.bool) bool           { return bool(n) }
349        func newC_uintptr_t(n C.uintptr_t) uint { return uint(n) }
350        func newC_intptr_t(n C.intptr_t) int    { return int(n) }
351        func newC_float(n C.float) float32      { return float32(n) }
352        func newC_double(n C.double) float64    { return float64(n) }
353
354        func cntC_uint8_t(_ *uint8, _ *uint) [0]C.uint8_t    { return [0]C.uint8_t{} }
355        func cntC_uint16_t(_ *uint16, _ *uint) [0]C.uint16_t { return [0]C.uint16_t{} }
356        func cntC_uint32_t(_ *uint32, _ *uint) [0]C.uint32_t { return [0]C.uint32_t{} }
357        func cntC_uint64_t(_ *uint64, _ *uint) [0]C.uint64_t { return [0]C.uint64_t{} }
358        func cntC_int8_t(_ *int8, _ *uint) [0]C.int8_t       { return [0]C.int8_t{} }
359        func cntC_int16_t(_ *int16, _ *uint) [0]C.int16_t    { return [0]C.int16_t{} }
360        func cntC_int32_t(_ *int32, _ *uint) [0]C.int32_t    { return [0]C.int32_t{} }
361        func cntC_int64_t(_ *int64, _ *uint) [0]C.int64_t    { return [0]C.int64_t{} }
362        func cntC_bool(_ *bool, _ *uint) [0]C.bool           { return [0]C.bool{} }
363        func cntC_uintptr_t(_ *uint, _ *uint) [0]C.uintptr_t { return [0]C.uintptr_t{} }
364        func cntC_intptr_t(_ *int, _ *uint) [0]C.intptr_t    { return [0]C.intptr_t{} }
365        func cntC_float(_ *float32, _ *uint) [0]C.float      { return [0]C.float{} }
366        func cntC_double(_ *float64, _ *uint) [0]C.double    { return [0]C.double{} }
367
368        func refC_uint8_t(p *uint8, _ *[]byte) C.uint8_t    { return C.uint8_t(*p) }
369        func refC_uint16_t(p *uint16, _ *[]byte) C.uint16_t { return C.uint16_t(*p) }
370        func refC_uint32_t(p *uint32, _ *[]byte) C.uint32_t { return C.uint32_t(*p) }
371        func refC_uint64_t(p *uint64, _ *[]byte) C.uint64_t { return C.uint64_t(*p) }
372        func refC_int8_t(p *int8, _ *[]byte) C.int8_t       { return C.int8_t(*p) }
373        func refC_int16_t(p *int16, _ *[]byte) C.int16_t    { return C.int16_t(*p) }
374        func refC_int32_t(p *int32, _ *[]byte) C.int32_t    { return C.int32_t(*p) }
375        func refC_int64_t(p *int64, _ *[]byte) C.int64_t    { return C.int64_t(*p) }
376        func refC_bool(p *bool, _ *[]byte) C.bool           { return C.bool(*p) }
377        func refC_uintptr_t(p *uint, _ *[]byte) C.uintptr_t { return C.uintptr_t(*p) }
378        func refC_intptr_t(p *int, _ *[]byte) C.intptr_t    { return C.intptr_t(*p) }
379        func refC_float(p *float32, _ *[]byte) C.float      { return C.float(*p) }
380        func refC_double(p *float64, _ *[]byte) C.double    { return C.double(*p) }
381        "#;
382        for item in self.file.items.iter() {
383            match item {
384                // for example, convert
385                // pub struct DemoRequest {
386                //     pub name: String,
387                //     pub age: u8,
388                // }
389                // to
390                // type DemoRequest struct {
391                //     name String
392                //     age uint8
393                // }
394                // func newDemoRequest(p C.DemoRequestRef) DemoRequest {
395                //     return DemoRequest {
396                //         name: newString(p.name),
397                //         age: uint8(p.age),
398                //     }
399                // }
400                // func refDemoRequest(p DemoRequest) C.DemoRequestRef {
401                //     return C.DemoRequestRef {
402                //         name: refString(p.name),
403                //         age: C.uint8_t(p.age),
404                //     }
405                // }
406                Item::Struct(s) => {
407                    let struct_name = s.ident.to_string();
408                    out.push_str(&format!("type {} struct {{\n", struct_name));
409                    for field in s.fields.iter() {
410                        let field_name = field
411                            .ident
412                            .as_ref()
413                            .ok_or_else(|| serr!("only named fields are supported"))?
414                            .to_string();
415                        let field_type = ParamType::try_from(&field.ty)?;
416                        out.push_str(&format!("    {} {}\n", field_name, field_type.to_go()));
417                    }
418                    out.push_str("}\n");
419
420                    // newStruct
421                    out.push_str(&format!(
422                        "func new{struct_name}(p C.{struct_name}Ref) {struct_name}{{\nreturn {struct_name}{{\n"
423                    ));
424                    for field in s.fields.iter() {
425                        let field_name = field.ident.as_ref().unwrap().to_string();
426                        let field_type = ParamType::try_from(&field.ty)?;
427                        let (new_f, _) = field_type.c_to_go_field_converter(levels);
428                        out.push_str(&format!("{field_name}: {new_f}(p.{field_name}),\n",));
429                    }
430                    out.push_str("}\n}\n");
431
432                    // ownStruct
433                    out.push_str(&format!(
434                        "func own{struct_name}(p C.{struct_name}Ref) {struct_name}{{\nreturn {struct_name}{{\n"
435                    ));
436                    for field in s.fields.iter() {
437                        let field_name = field.ident.as_ref().unwrap().to_string();
438                        let field_type = ParamType::try_from(&field.ty)?;
439                        let own_f = field_type.c_to_go_field_converter_owned();
440                        out.push_str(&format!("{field_name}: {own_f}(p.{field_name}),\n",));
441                    }
442                    out.push_str("}\n}\n");
443
444                    // cntStruct
445                    let level = *levels.get(&s.ident).unwrap();
446                    out.push_str(&format!(
447                        "func cnt{struct_name}(s *{struct_name}, cnt *uint) [0]C.{struct_name}Ref {{\n"
448                    ));
449                    let mut used = false;
450                    if level == 2 {
451                        for field in s.fields.iter() {
452                            let field_name = field.ident.as_ref().unwrap().to_string();
453                            let field_type = ParamType::try_from(&field.ty)?;
454                            let (counter_f, level) = field_type.go_to_c_field_counter(levels);
455                            if level == 2 {
456                                out.push_str(&format!("{counter_f}(&s.{field_name}, cnt)\n"));
457                                used = true;
458                            }
459                        }
460                    }
461                    if !used {
462                        out.push_str("_ = s\n_ = cnt\n");
463                    }
464                    out.push_str(&format!("return [0]C.{struct_name}Ref{{}}\n"));
465                    out.push_str("}\n");
466
467                    // refStruct
468                    out.push_str(&format!(
469                        "func ref{struct_name}(p *{struct_name}, buffer *[]byte) C.{struct_name}Ref{{\nreturn C.{struct_name}Ref{{\n"
470                    ));
471                    for field in s.fields.iter() {
472                        let field_name = field.ident.as_ref().unwrap().to_string();
473                        let field_type = ParamType::try_from(&field.ty)?;
474                        let (ref_f, _) = field_type.go_to_c_field_converter(levels);
475                        out.push_str(&format!(
476                            "{field_name}: {ref_f}(&p.{field_name}, buffer),\n",
477                        ));
478                    }
479                    out.push_str("}\n}\n");
480                }
481                _ => continue,
482            }
483        }
484        Ok(out)
485    }
486
487    pub fn convert_r2g_trait(&self) -> Result<Vec<R2GTraitRepr>> {
488        let out: Vec<R2GTraitRepr> = self
489            .file
490            .items
491            .iter()
492            .filter_map(|item| match item {
493                Item::Trait(t)
494                    if t.attrs
495                        .iter()
496                        .any(|attr| attr.meta.path().segments.last().unwrap().ident == "r2g") =>
497                {
498                    Some(t)
499                }
500                _ => None,
501            })
502            .map(|trat| trat.try_into())
503            .collect::<Result<Vec<R2GTraitRepr>>>()?;
504        Ok(out)
505    }
506
507    pub fn convert_g2r_trait(&self) -> Result<Vec<G2RTraitRepr>> {
508        let out: Vec<G2RTraitRepr> = self
509            .file
510            .items
511            .iter()
512            .filter_map(|item| match item {
513                Item::Trait(t)
514                    if t.attrs
515                        .iter()
516                        .any(|attr| attr.meta.path().segments.last().unwrap().ident == "g2r") =>
517                {
518                    Some(t)
519                }
520                _ => None,
521            })
522            .map(|trat| trat.try_into())
523            .collect::<Result<Vec<G2RTraitRepr>>>()?;
524        Ok(out)
525    }
526
527    // 0->Primitive
528    // 1->SimpleWrapper
529    // 2->Complex
530    pub fn convert_structs_levels(&self) -> Result<HashMap<Ident, u8>> {
531        enum Node {
532            List(Box<Node>),
533            NamedStruct(Ident),
534            Primitive,
535        }
536        fn type_to_node(ty: &Type) -> Result<Node> {
537            let seg = type_to_segment(ty)?;
538            match seg.ident.to_string().as_str() {
539                "Vec" => {
540                    let inside = match &seg.arguments {
541                        syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
542                            syn::GenericArgument::Type(ty) => ty,
543                            _ => panic!("list generic must be a type"),
544                        },
545                        _ => panic!("list type must have angle bracketed arguments"),
546                    };
547                    Ok(Node::List(Box::new(type_to_node(inside)?)))
548                }
549                "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64" | "isize"
550                | "bool" | "char" | "f32" | "f64" => Ok(Node::Primitive),
551                _ => Ok(Node::NamedStruct(seg.ident.clone())),
552            }
553        }
554        fn node_level(
555            node: &Node,
556            items: &HashMap<Ident, Vec<Node>>,
557            out: &mut HashMap<Ident, u8>,
558        ) -> u8 {
559            match node {
560                Node::List(inner) => (1 + node_level(inner, items, out)).min(2),
561                Node::NamedStruct(ident) if ident.to_string().as_str() == "String" => 1,
562                Node::NamedStruct(name) => {
563                    if let Some(lv) = out.get(name) {
564                        return *lv;
565                    }
566                    let lv = items
567                        .get(name)
568                        .map(|nodes| {
569                            nodes
570                                .iter()
571                                .map(|n| node_level(n, items, out))
572                                .max()
573                                .unwrap_or(0)
574                        })
575                        .unwrap();
576                    out.insert(name.clone(), lv);
577                    lv
578                }
579                Node::Primitive => 0,
580            }
581        }
582        let mut items = HashMap::<Ident, Vec<Node>>::new();
583        for item in self.file.items.iter() {
584            match item {
585                Item::Struct(s) => {
586                    let mut fields = Vec::new();
587                    for field in &s.fields {
588                        fields.push(type_to_node(&field.ty)?);
589                    }
590                    items.insert(s.ident.clone(), fields);
591                }
592                _ => continue,
593            }
594        }
595
596        let mut out = HashMap::new();
597        for name in items.keys() {
598            let lv = node_level(&Node::NamedStruct(name.clone()), &items, &mut out);
599            out.insert(name.clone(), lv);
600        }
601        out.insert(Ident::new("String", Span::call_site()), 1);
602        Ok(out)
603    }
604}
605
606pub struct Param {
607    pub name: Ident,
608    pub ty: ParamType,
609}
610
611impl Param {
612    pub fn ty(&self) -> &ParamType {
613        &self.ty
614    }
615}
616
/// Classified form of a field or parameter type, as produced by
/// `ParamType::try_from(&Type)`.
pub struct ParamType {
    // The type with at most one leading reference layer removed.
    pub inner: ParamTypeInner,
    // True if the original type was a reference (`&T` or `&mut T`; mutability is not kept).
    pub is_reference: bool,
}

/// The kind of a (de-referenced) type.
pub enum ParamTypeInner {
    // A built-in scalar (`u8`..`u64`, `i8`..`i64`, `usize`, `isize`, `bool`, `char`,
    // `f32`, `f64`); holds the type name.
    Primitive(Ident),
    // Any other non-generic named type, e.g. a user struct or `String`; holds the name.
    Custom(Ident),
    // A `Vec<..>`; holds the full `Vec<..>` type so the element type can be recovered.
    List(Type),
}
627
628impl ToTokens for ParamType {
629    fn to_tokens(&self, tokens: &mut TokenStream) {
630        if self.is_reference {
631            tokens.extend(quote! {&});
632        }
633        match &self.inner {
634            ParamTypeInner::Primitive(ty) => ty.to_tokens(tokens),
635            ParamTypeInner::Custom(ty) => ty.to_tokens(tokens),
636            ParamTypeInner::List(ty) => ty.to_tokens(tokens),
637        }
638    }
639}
640
641impl TryFrom<&Type> for ParamType {
642    type Error = Error;
643
644    fn try_from(mut ty: &Type) -> Result<Self> {
645        let mut is_reference = false;
646        if let Type::Reference(r) = ty {
647            is_reference = true;
648            ty = &r.elem;
649        }
650
651        // TypePath -> ParamType
652        let seg = type_to_segment(ty)?;
653        let param_type_inner = match seg.ident.to_string().as_str() {
654            "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "usize" | "isize"
655            | "bool" | "char" | "f32" | "f64" => {
656                if !seg.arguments.is_none() {
657                    sbail!("primitive types with arguments are not supported")
658                }
659                ParamTypeInner::Primitive(seg.ident.clone())
660            }
661            "Vec" => ParamTypeInner::List(ty.clone()),
662            _ => {
663                if !seg.arguments.is_none() {
664                    sbail!("custom types with arguments are not supported")
665                }
666                ParamTypeInner::Custom(seg.ident.clone())
667            }
668        };
669        Ok(ParamType {
670            inner: param_type_inner,
671            is_reference,
672        })
673    }
674}
675
676impl ParamType {
677    pub fn to_c(&self, with_struct: bool) -> String {
678        let struct_ = if with_struct { "struct " } else { "" };
679        match &self.inner {
680            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
681                "u8" => "uint8_t",
682                "u16" => "uint16_t",
683                "u32" => "uint32_t",
684                "u64" => "uint64_t",
685                "i8" => "int8_t",
686                "i16" => "int16_t",
687                "i32" => "int32_t",
688                "i64" => "int64_t",
689                "bool" => "bool",
690                "char" => "uint32_t",
691                "usize" => "uintptr_t",
692                "isize" => "intptr_t",
693                "f32" => "float",
694                "f64" => "double",
695                _ => panic!("unreconigzed rust primitive type {name}"),
696            }
697            .to_string(),
698            ParamTypeInner::Custom(c) => format!("{struct_}{c}Ref"),
699            ParamTypeInner::List(_) => format!("{struct_}ListRef"),
700        }
701    }
702
703    pub fn to_go(&self) -> String {
704        match &self.inner {
705            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
706                "u8" => "uint8",
707                "u16" => "uint16",
708                "u32" => "uint32",
709                "u64" => "uint64",
710                "i8" => "int8",
711                "i16" => "int16",
712                "i32" => "int32",
713                "i64" => "int64",
714                "bool" => "bool",
715                "char" => "rune",
716                "usize" => "uint",
717                "isize" => "int",
718                "f32" => "float32",
719                "f64" => "float64",
720                _ => panic!("unreconigzed rust primitive type {name}"),
721            }
722            .to_string(),
723            ParamTypeInner::Custom(c) => {
724                let s = c.to_string();
725                match s.as_str() {
726                    "String" => "string".to_string(),
727                    _ => s,
728                }
729            }
730            ParamTypeInner::List(inner) => {
731                let seg = type_to_segment(inner).unwrap();
732                let inside = match &seg.arguments {
733                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
734                        syn::GenericArgument::Type(ty) => ty,
735                        _ => panic!("list generic must be a type"),
736                    },
737                    _ => panic!("list type must have angle bracketed arguments"),
738                };
739                format!(
740                    "[]{}",
741                    ParamType::try_from(inside)
742                        .expect("unable to convert list type")
743                        .to_go()
744                )
745            }
746        }
747    }
748
749    // f: StructRef -> Struct
750    pub fn c_to_go_field_converter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
751        match &self.inner {
752            ParamTypeInner::Primitive(name) => (
753                match name.to_string().as_str() {
754                    "u8" => "newC_uint8_t",
755                    "u16" => "newC_uint16_t",
756                    "u32" => "newC_uint32_t",
757                    "u64" => "newC_uint64_t",
758                    "i8" => "newC_int8_t",
759                    "i16" => "newC_int16_t",
760                    "i32" => "newC_int32_t",
761                    "i64" => "newC_int64_t",
762                    "bool" => "newC_bool",
763                    "usize" => "newC_uintptr_t",
764                    "isize" => "newC_intptr_t",
765                    "f32" => "newC_float",
766                    "f64" => "newC_double",
767                    _ => panic!("unrecognized rust primitive type {name}"),
768                }
769                .to_string(),
770                0,
771            ),
772            ParamTypeInner::Custom(c) => (
773                format!("new{}", c.to_string().as_str()),
774                *mapping.get(c).unwrap(),
775            ),
776            ParamTypeInner::List(inner) => {
777                let seg = type_to_segment(inner).unwrap();
778                let inside = match &seg.arguments {
779                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
780                        syn::GenericArgument::Type(ty) => ty,
781                        _ => panic!("list generic must be a type"),
782                    },
783                    _ => panic!("list type must have angle bracketed arguments"),
784                };
785                let (inner, inner_level) = ParamType::try_from(inside)
786                    .expect("unable to convert list type")
787                    .c_to_go_field_converter(mapping);
788                if inner_level == 0 {
789                    (format!("new_list_mapper_primitive({inner})"), 1)
790                } else {
791                    (format!("new_list_mapper({inner})"), 2.min(inner_level + 1))
792                }
793            }
794        }
795    }
796
797    // f: StructRef -> Struct with fully ownership
798    pub fn c_to_go_field_converter_owned(&self) -> String {
799        match &self.inner {
800            ParamTypeInner::Primitive(name) => match name.to_string().as_str() {
801                "u8" => "newC_uint8_t",
802                "u16" => "newC_uint16_t",
803                "u32" => "newC_uint32_t",
804                "u64" => "newC_uint64_t",
805                "i8" => "newC_int8_t",
806                "i16" => "newC_int16_t",
807                "i32" => "newC_int32_t",
808                "i64" => "newC_int64_t",
809                "bool" => "newC_bool",
810                "usize" => "newC_uintptr_t",
811                "isize" => "newC_intptr_t",
812                "f32" => "newC_float",
813                "f64" => "newC_double",
814                _ => panic!("unrecognized rust primitive type {name}"),
815            }
816            .to_string(),
817            ParamTypeInner::Custom(c) => format!("own{}", c.to_string().as_str()),
818            ParamTypeInner::List(inner) => {
819                let seg = type_to_segment(inner).unwrap();
820                let inside = match &seg.arguments {
821                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
822                        syn::GenericArgument::Type(ty) => ty,
823                        _ => panic!("list generic must be a type"),
824                    },
825                    _ => panic!("list type must have angle bracketed arguments"),
826                };
827                let inner = ParamType::try_from(inside)
828                    .expect("unable to convert list type")
829                    .c_to_go_field_converter_owned();
830                format!("new_list_mapper({inner})")
831            }
832        }
833    }
834
835    pub fn go_to_c_field_counter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
836        match &self.inner {
837            ParamTypeInner::Primitive(name) => (
838                match name.to_string().as_str() {
839                    "u8" => "cntC_uint8_t",
840                    "u16" => "cntC_uint16_t",
841                    "u32" => "cntC_uint32_t",
842                    "u64" => "cntC_uint64_t",
843                    "i8" => "cntC_int8_t",
844                    "i16" => "cntC_int16_t",
845                    "i32" => "cntC_int32_t",
846                    "i64" => "cntC_int64_t",
847                    "bool" => "cntC_bool",
848                    "usize" => "cntC_uintptr_t",
849                    "isize" => "cntC_intptr_t",
850                    "f32" => "cntC_float",
851                    "f64" => "cntC_double",
852                    _ => panic!("unrecognized rust primitive type {name}"),
853                }
854                .to_string(),
855                0,
856            ),
857            ParamTypeInner::Custom(c) => (
858                format!("cnt{}", c.to_string().as_str()),
859                *mapping.get(c).unwrap(),
860            ),
861            ParamTypeInner::List(inner) => {
862                let seg = type_to_segment(inner).unwrap();
863                let inside = match &seg.arguments {
864                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
865                        syn::GenericArgument::Type(ty) => ty,
866                        _ => panic!("list generic must be a type"),
867                    },
868                    _ => panic!("list type must have angle bracketed arguments"),
869                };
870                let (inner, inner_level) = ParamType::try_from(inside)
871                    .expect("unable to convert list type")
872                    .go_to_c_field_counter(mapping);
873                if inner_level == 0 {
874                    (format!("cnt_list_mapper_primitive({inner})"), 1)
875                } else {
876                    (format!("cnt_list_mapper({inner})"), 2.min(inner_level + 1))
877                }
878            }
879        }
880    }
881
882    // f: Struct -> StructRef
883    pub fn go_to_c_field_converter(&self, mapping: &HashMap<Ident, u8>) -> (String, u8) {
884        match &self.inner {
885            ParamTypeInner::Primitive(name) => (
886                match name.to_string().as_str() {
887                    "u8" => "refC_uint8_t",
888                    "u16" => "refC_uint16_t",
889                    "u32" => "refC_uint32_t",
890                    "u64" => "refC_uint64_t",
891                    "i8" => "refC_int8_t",
892                    "i16" => "refC_int16_t",
893                    "i32" => "refC_int32_t",
894                    "i64" => "refC_int64_t",
895                    "bool" => "refC_bool",
896                    "usize" => "refC_uintptr_t",
897                    "isize" => "refC_intptr_t",
898                    "f32" => "refC_float",
899                    "f64" => "refC_double",
900                    _ => panic!("unreconigzed rust primitive type {name}"),
901                }
902                .to_string(),
903                0,
904            ),
905            ParamTypeInner::Custom(c) => (
906                format!("ref{}", c.to_string().as_str()),
907                *mapping.get(c).unwrap(),
908            ),
909            ParamTypeInner::List(inner) => {
910                let seg = type_to_segment(inner).unwrap();
911                let inside = match &seg.arguments {
912                    syn::PathArguments::AngleBracketed(ga) => match ga.args.last().unwrap() {
913                        syn::GenericArgument::Type(ty) => ty,
914                        _ => panic!("list generic must be a type"),
915                    },
916                    _ => panic!("list type must have angle bracketed arguments"),
917                };
918                let (inner, inner_level) = ParamType::try_from(inside)
919                    .expect("unable to convert list type")
920                    .go_to_c_field_converter(mapping);
921                if inner_level == 0 {
922                    (format!("ref_list_mapper_primitive({inner})"), 1)
923                } else {
924                    (format!("ref_list_mapper({inner})"), 2.min(inner_level + 1))
925                }
926            }
927        }
928    }
929
930    pub fn to_rust_ref(&self, prefix: Option<&TokenStream>) -> TokenStream {
931        match &self.inner {
932            ParamTypeInner::Primitive(name) => quote!(#name),
933            ParamTypeInner::Custom(name) => {
934                let ident = format_ident!("{}Ref", name);
935                quote!(#prefix #ident)
936            }
937            ParamTypeInner::List(_) => {
938                let ident = format_ident!("ListRef");
939                quote!(#prefix #ident)
940            }
941        }
942    }
943}
944
945pub(crate) fn type_to_segment(ty: &Type) -> Result<&PathSegment> {
946    let field_type = match ty {
947        Type::Path(p) => p,
948        _ => sbail!("only path types are supported"),
949    };
950    let path = &field_type.path;
951    // Leading colon is not allow
952    if path.leading_colon.is_some() {
953        sbail!("types with leading colons are not supported");
954    }
955    // We only accept single-segment path
956    if path.segments.len() != 1 {
957        sbail!("types with multiple segments are not supported");
958    }
959    Ok(path.segments.first().unwrap())
960}
961
#[cfg(test)]
mod tests {
    /// Smoke test: parses a small file with two structs and one trait, then
    /// prints the generated Go structs, interface, exports and struct levels.
    ///
    /// It only checks that parsing and generation succeed (the `unwrap`s).
    /// The printed output is not asserted.
    #[test]
    fn it_works() {
        let raw = r#"
        pub struct DemoRequest {
            pub name: String,
            pub age: u8,
        }
        pub struct DemoResponse {
            pub pass: bool,
        }
        pub trait DemoCall {
            fn demo_check(req: DemoRequest) -> DemoResponse;
            fn demo_check_async(req: DemoRequest) -> impl std::future::Future<Output = DemoResponse>;
        }
        "#;
        let raw_file = super::RawRsFile::new(raw);
        let traits = raw_file.convert_r2g_trait().unwrap();
        let levels = raw_file.convert_structs_levels().unwrap();

        println!(
            "structs gen: {}",
            raw_file.convert_structs_to_go(&levels, false).unwrap()
        );
        for trait_ in traits {
            println!("if gen: {}", trait_.generate_go_interface());
            println!("go export gen: {}", trait_.generate_go_exports(&levels));
        }
        // `levels` was only borrowed above, so reuse it rather than
        // re-deriving the same levels from the file a second time.
        levels.iter().for_each(|f| println!("{}: {}", f.0, f.1));
    }
}