rust2go_macro/
lib.rs

1// Copyright 2024 ihciah. All Rights Reserved.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use rust2go_common::{g2r::G2RTraitRepr, r2g::R2GTraitRepr, sbail};
6use syn::{parse::Parser, parse_macro_input, DeriveInput, Ident};
7
8#[proc_macro_derive(R2G)]
9pub fn r2g_derive(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    // Skip derive when the type has generics.
12    if !input.generics.params.is_empty() {
13        return TokenStream::default();
14    }
15    // Skip derive when the type is not struct.
16    let data = match input.data {
17        syn::Data::Struct(d) => d,
18        _ => return TokenStream::default(),
19    };
20    let type_name = input.ident;
21    let type_name_str = type_name.to_string();
22
23    let ref_type_name = Ident::new(&format!("{type_name_str}Ref"), type_name.span());
24    let mut ref_fields = Vec::with_capacity(data.fields.len());
25    for field in data.fields.iter() {
26        let name = field.ident.as_ref().unwrap();
27        let ty = &field.ty;
28        let syn::Type::Path(path) = ty else {
29            return TokenStream::default();
30        };
31        let Some(first_seg) = path.path.segments.first() else {
32            return TokenStream::default();
33        };
34        match first_seg.ident.to_string().as_str() {
35            "Vec" => {
36                ref_fields.push(quote! {#name: ::rust2go::ListRef});
37            }
38            "String" => {
39                ref_fields.push(quote! {#name: ::rust2go::StringRef});
40            }
41            "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize"
42            | "f32" | "f64" | "bool" | "char" => {
43                ref_fields.push(quote! {#name: #ty});
44            }
45            ty => {
46                let ref_type = format_ident!("{ty}Ref");
47                ref_fields.push(quote! {#name: #ref_type});
48            }
49        }
50    }
51
52    let mut owned_names = Vec::with_capacity(data.fields.len());
53    let mut owned_types = Vec::with_capacity(data.fields.len());
54    for field in data.fields.iter() {
55        owned_names.push(field.ident.clone().unwrap());
56        owned_types.push(field.ty.clone());
57    }
58
59    let expanded = quote! {
60        #[repr(C)]
61        pub struct #ref_type_name {
62            #(#ref_fields),*
63        }
64
65        impl ::rust2go::ToRef for #type_name {
66            const MEM_TYPE: ::rust2go::MemType = ::rust2go::max_mem_type!(#(#owned_types),*);
67            type Ref = #ref_type_name;
68
69            fn to_size(&self, acc: &mut usize) {
70                if matches!(Self::MEM_TYPE, ::rust2go::MemType::Complex) {
71                    #(self.#owned_names.to_size(acc);)*
72                }
73            }
74
75            fn to_ref(&self, buffer: &mut ::rust2go::Writer) -> Self::Ref {
76                #ref_type_name {
77                    #(#owned_names: ::rust2go::ToRef::to_ref(&self.#owned_names, buffer),)*
78                }
79            }
80        }
81
82        impl ::rust2go::FromRef for #type_name {
83            type Ref = #ref_type_name;
84
85            fn from_ref(ref_: &Self::Ref) -> Self {
86                Self {
87                    #(#owned_names: ::rust2go::FromRef::from_ref(&ref_.#owned_names),)*
88                }
89            }
90        }
91    };
92    TokenStream::from(expanded)
93}
94
95fn parse_attrs(attrs: TokenStream) -> (Option<syn::Path>, Option<usize>) {
96    let mut binding_path = None;
97    let mut queue_size = None;
98
99    type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
100    if let Ok(attrs) = AttributeArgs::parse_terminated.parse(attrs) {
101        for attr in attrs {
102            match attr {
103                syn::Meta::NameValue(nv) => {
104                    if nv.path.is_ident("binding") {
105                        binding_path = Some(nv.path);
106                    } else if nv.path.is_ident("queue_size") {
107                        if let syn::Expr::Lit(syn::ExprLit {
108                            lit: syn::Lit::Int(litint),
109                            ..
110                        }) = nv.value
111                        {
112                            queue_size = Some(litint.base10_parse::<usize>().unwrap());
113                        }
114                    }
115                }
116                syn::Meta::Path(p) => {
117                    binding_path = Some(p);
118                }
119                _ => {}
120            }
121        }
122    }
123    (binding_path, queue_size)
124}
125
126#[proc_macro_attribute]
127pub fn r2g(attrs: TokenStream, item: TokenStream) -> TokenStream {
128    let (binding_path, queue_size) = parse_attrs(attrs);
129    syn::parse::<syn::ItemTrait>(item)
130        .and_then(|trat| r2g_trait(binding_path, queue_size, trat))
131        .unwrap_or_else(|e| TokenStream::from(e.to_compile_error()))
132}
133
134#[proc_macro_attribute]
135pub fn g2r(_attrs: TokenStream, item: TokenStream) -> TokenStream {
136    syn::parse::<syn::ItemTrait>(item)
137        .and_then(g2r_trait)
138        .unwrap_or_else(|e| TokenStream::from(e.to_compile_error()))
139}
140
141fn g2r_trait(mut trat: syn::ItemTrait) -> syn::Result<TokenStream> {
142    let trat_repr = G2RTraitRepr::try_from(&trat)?;
143
144    for trat_fn in trat.items.iter_mut() {
145        match trat_fn {
146            syn::TraitItem::Fn(f) => {
147                // remove attributes of all functions
148                f.attrs.clear();
149            }
150            _ => sbail!("only fn is supported"),
151        }
152    }
153
154    let mut out = quote! {#trat};
155    out.extend(trat_repr.generate_rs()?);
156    Ok(out.into())
157}
158
159fn r2g_trait(
160    binding_path: Option<syn::Path>,
161    queue_size: Option<usize>,
162    mut trat: syn::ItemTrait,
163) -> syn::Result<TokenStream> {
164    let trat_repr = R2GTraitRepr::try_from(&trat)?;
165
166    for (fn_repr, trat_fn) in trat_repr.fns().iter().zip(trat.items.iter_mut()) {
167        match trat_fn {
168            syn::TraitItem::Fn(f) => {
169                // remove attributes of all functions
170                f.attrs.clear();
171
172                // for shm based oneway call, add unsafe
173                if fn_repr.ret().is_none() && !fn_repr.is_async() && fn_repr.mem_call_id().is_some()
174                {
175                    f.sig.unsafety = Some(syn::token::Unsafe::default());
176                }
177
178                // convert async fn return impl future
179                if fn_repr.is_async() {
180                    let orig = match fn_repr.ret() {
181                        None => quote! { () },
182                        Some(ret) => quote! { #ret },
183                    };
184                    let auto_t = match (fn_repr.ret_send(), fn_repr.ret_static()) {
185                        (true, true) => quote!( + Send + Sync + 'static),
186                        (true, false) => quote!( + Send + Sync),
187                        (false, true) => quote!( + 'static),
188                        (false, false) => quote!(),
189                    };
190                    f.sig.asyncness = None;
191                    if fn_repr.drop_safe_ret_params() {
192                        // for all functions with #[drop_safe_ret], change the return type.
193                        let tys = fn_repr.params().iter().map(|p| p.ty());
194                        f.sig.output = syn::parse_quote! { -> impl ::std::future::Future<Output = (#orig, (#(#tys,)*))> #auto_t };
195                    } else {
196                        f.sig.output = syn::parse_quote! { -> impl ::std::future::Future<Output = #orig> #auto_t };
197                    }
198
199                    // for all functions with safe=false, add unsafe
200                    if !fn_repr.is_safe() {
201                        f.sig.unsafety = Some(syn::token::Unsafe::default());
202                    }
203                }
204            }
205            _ => sbail!("only fn is supported"),
206        }
207    }
208
209    let mut out = quote! {#trat};
210    out.extend(trat_repr.generate_rs(binding_path.as_ref(), queue_size)?);
211    Ok(out.into())
212}