1use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use rust2go_common::{g2r::G2RTraitRepr, r2g::R2GTraitRepr, sbail};
6use syn::{parse::Parser, parse_macro_input, DeriveInput, Ident};
7
8#[proc_macro_derive(R2G)]
9pub fn r2g_derive(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 if !input.generics.params.is_empty() {
13 return TokenStream::default();
14 }
15 let data = match input.data {
17 syn::Data::Struct(d) => d,
18 _ => return TokenStream::default(),
19 };
20 let type_name = input.ident;
21 let type_name_str = type_name.to_string();
22
23 let ref_type_name = Ident::new(&format!("{type_name_str}Ref"), type_name.span());
24 let mut ref_fields = Vec::with_capacity(data.fields.len());
25 for field in data.fields.iter() {
26 let name = field.ident.as_ref().unwrap();
27 let ty = &field.ty;
28 let syn::Type::Path(path) = ty else {
29 return TokenStream::default();
30 };
31 let Some(first_seg) = path.path.segments.first() else {
32 return TokenStream::default();
33 };
34 match first_seg.ident.to_string().as_str() {
35 "Vec" => {
36 ref_fields.push(quote! {#name: ::rust2go::ListRef});
37 }
38 "String" => {
39 ref_fields.push(quote! {#name: ::rust2go::StringRef});
40 }
41 "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize"
42 | "f32" | "f64" | "bool" | "char" => {
43 ref_fields.push(quote! {#name: #ty});
44 }
45 ty => {
46 let ref_type = format_ident!("{ty}Ref");
47 ref_fields.push(quote! {#name: #ref_type});
48 }
49 }
50 }
51
52 let mut owned_names = Vec::with_capacity(data.fields.len());
53 let mut owned_types = Vec::with_capacity(data.fields.len());
54 for field in data.fields.iter() {
55 owned_names.push(field.ident.clone().unwrap());
56 owned_types.push(field.ty.clone());
57 }
58
59 let expanded = quote! {
60 #[repr(C)]
61 pub struct #ref_type_name {
62 #(#ref_fields),*
63 }
64
65 impl ::rust2go::ToRef for #type_name {
66 const MEM_TYPE: ::rust2go::MemType = ::rust2go::max_mem_type!(#(#owned_types),*);
67 type Ref = #ref_type_name;
68
69 fn to_size(&self, acc: &mut usize) {
70 if matches!(Self::MEM_TYPE, ::rust2go::MemType::Complex) {
71 #(self.#owned_names.to_size(acc);)*
72 }
73 }
74
75 fn to_ref(&self, buffer: &mut ::rust2go::Writer) -> Self::Ref {
76 #ref_type_name {
77 #(#owned_names: ::rust2go::ToRef::to_ref(&self.#owned_names, buffer),)*
78 }
79 }
80 }
81
82 impl ::rust2go::FromRef for #type_name {
83 type Ref = #ref_type_name;
84
85 fn from_ref(ref_: &Self::Ref) -> Self {
86 Self {
87 #(#owned_names: ::rust2go::FromRef::from_ref(&ref_.#owned_names),)*
88 }
89 }
90 }
91 };
92 TokenStream::from(expanded)
93}
94
95fn parse_attrs(attrs: TokenStream) -> (Option<syn::Path>, Option<usize>) {
96 let mut binding_path = None;
97 let mut queue_size = None;
98
99 type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
100 if let Ok(attrs) = AttributeArgs::parse_terminated.parse(attrs) {
101 for attr in attrs {
102 match attr {
103 syn::Meta::NameValue(nv) => {
104 if nv.path.is_ident("binding") {
105 binding_path = Some(nv.path);
106 } else if nv.path.is_ident("queue_size") {
107 if let syn::Expr::Lit(syn::ExprLit {
108 lit: syn::Lit::Int(litint),
109 ..
110 }) = nv.value
111 {
112 queue_size = Some(litint.base10_parse::<usize>().unwrap());
113 }
114 }
115 }
116 syn::Meta::Path(p) => {
117 binding_path = Some(p);
118 }
119 _ => {}
120 }
121 }
122 }
123 (binding_path, queue_size)
124}
125
126#[proc_macro_attribute]
127pub fn r2g(attrs: TokenStream, item: TokenStream) -> TokenStream {
128 let (binding_path, queue_size) = parse_attrs(attrs);
129 syn::parse::<syn::ItemTrait>(item)
130 .and_then(|trat| r2g_trait(binding_path, queue_size, trat))
131 .unwrap_or_else(|e| TokenStream::from(e.to_compile_error()))
132}
133
134#[proc_macro_attribute]
135pub fn g2r(_attrs: TokenStream, item: TokenStream) -> TokenStream {
136 syn::parse::<syn::ItemTrait>(item)
137 .and_then(g2r_trait)
138 .unwrap_or_else(|e| TokenStream::from(e.to_compile_error()))
139}
140
141fn g2r_trait(mut trat: syn::ItemTrait) -> syn::Result<TokenStream> {
142 let trat_repr = G2RTraitRepr::try_from(&trat)?;
143
144 for trat_fn in trat.items.iter_mut() {
145 match trat_fn {
146 syn::TraitItem::Fn(f) => {
147 f.attrs.clear();
149 }
150 _ => sbail!("only fn is supported"),
151 }
152 }
153
154 let mut out = quote! {#trat};
155 out.extend(trat_repr.generate_rs()?);
156 Ok(out.into())
157}
158
159fn r2g_trait(
160 binding_path: Option<syn::Path>,
161 queue_size: Option<usize>,
162 mut trat: syn::ItemTrait,
163) -> syn::Result<TokenStream> {
164 let trat_repr = R2GTraitRepr::try_from(&trat)?;
165
166 for (fn_repr, trat_fn) in trat_repr.fns().iter().zip(trat.items.iter_mut()) {
167 match trat_fn {
168 syn::TraitItem::Fn(f) => {
169 f.attrs.clear();
171
172 if fn_repr.ret().is_none() && !fn_repr.is_async() && fn_repr.mem_call_id().is_some()
174 {
175 f.sig.unsafety = Some(syn::token::Unsafe::default());
176 }
177
178 if fn_repr.is_async() {
180 let orig = match fn_repr.ret() {
181 None => quote! { () },
182 Some(ret) => quote! { #ret },
183 };
184 let auto_t = match (fn_repr.ret_send(), fn_repr.ret_static()) {
185 (true, true) => quote!( + Send + Sync + 'static),
186 (true, false) => quote!( + Send + Sync),
187 (false, true) => quote!( + 'static),
188 (false, false) => quote!(),
189 };
190 f.sig.asyncness = None;
191 if fn_repr.drop_safe_ret_params() {
192 let tys = fn_repr.params().iter().map(|p| p.ty());
194 f.sig.output = syn::parse_quote! { -> impl ::std::future::Future<Output = (#orig, (#(#tys,)*))> #auto_t };
195 } else {
196 f.sig.output = syn::parse_quote! { -> impl ::std::future::Future<Output = #orig> #auto_t };
197 }
198
199 if !fn_repr.is_safe() {
201 f.sig.unsafety = Some(syn::token::Unsafe::default());
202 }
203 }
204 }
205 _ => sbail!("only fn is supported"),
206 }
207 }
208
209 let mut out = quote! {#trat};
210 out.extend(trat_repr.generate_rs(binding_path.as_ref(), queue_size)?);
211 Ok(out.into())
212}