scheme_rs_macros/
lib.rs

1use proc_macro::{self, TokenStream};
2use proc_macro2::Span;
3use quote::quote;
4use syn::{
5    parse_macro_input, parse_quote, punctuated::Punctuated, DataEnum, DataStruct, DeriveInput,
6    Fields, FnArg, GenericParam, Generics, Ident, ItemFn, Member, PatType, Token, Type,
7};
8
9#[proc_macro_attribute]
10pub fn builtin(name: TokenStream, item: TokenStream) -> TokenStream {
11    let name = proc_macro2::TokenStream::from(name);
12    let builtin = parse_macro_input!(item as ItemFn);
13
14    let impl_name = builtin.sig.ident.clone();
15    let wrapper_name = impl_name.to_string() + "_wrapper";
16    let wrapper_name = Ident::new(&wrapper_name, Span::call_site());
17
18    let is_variadic = if let Some(last_arg) = builtin.sig.inputs.last() {
19        is_vec(last_arg)
20    } else {
21        false
22    };
23
24    let num_args = if is_variadic {
25        builtin.sig.inputs.len().saturating_sub(2)
26    } else {
27        builtin.sig.inputs.len() - 1
28    };
29
30    let wrapper: ItemFn = if !is_variadic {
31        let arg_indices: Vec<_> = (0..num_args).collect();
32        parse_quote! {
33            fn #wrapper_name(
34                cont: Option<std::sync::Arc<::scheme_rs::continuation::Continuation>>,
35                args: Vec<::scheme_rs::gc::Gc<::scheme_rs::value::Value>>
36            ) -> futures::future::BoxFuture<'static, Result<Vec<::scheme_rs::gc::Gc<::scheme_rs::value::Value>>, ::scheme_rs::error::RuntimeError>> {
37                Box::pin(
38                    async move {
39                        #impl_name(
40                            &cont,
41                            #( &args[#arg_indices], )*
42                        ).await
43                    }
44                )
45            }
46        }
47    } else {
48        let arg_indices: Vec<_> = (0..num_args).collect();
49        parse_quote! {
50            fn #wrapper_name(
51                cont: Option<std::sync::Arc<::scheme_rs::continuation::Continuation>>,
52                mut required_args: Vec<::scheme_rs::gc::Gc<::scheme_rs::value::Value>>
53            ) -> futures::future::BoxFuture<'static, Result<Vec<::scheme_rs::gc::Gc<::scheme_rs::value::Value>>, ::scheme_rs::error::RuntimeError>> {
54                let var_args = required_args.split_off(#num_args);
55                Box::pin(
56                    async move {
57                        #impl_name(
58                            &cont,
59                            #( &required_args[#arg_indices], )*
60                            var_args
61                        ).await
62                    }
63                )
64            }
65        }
66    };
67    quote! {
68        #builtin
69
70        #wrapper
71
72        inventory::submit! {
73            ::scheme_rs::builtin::Builtin::new(#name, #num_args, #is_variadic, #wrapper_name)
74        }
75    }
76    .into()
77}
78
79fn is_vec(arg: &FnArg) -> bool {
80    if let FnArg::Typed(PatType { ty, .. }) = arg {
81        if let Type::Path(ref path) = ty.as_ref() {
82            return path
83                .path
84                .segments
85                .last()
86                .map(|p| p.ident.to_string())
87                .as_deref()
88                == Some("Vec");
89        }
90    }
91    false
92}
93
94#[proc_macro_derive(Trace)]
95pub fn derive_trace(input: TokenStream) -> TokenStream {
96    let DeriveInput {
97        ident,
98        data,
99        generics,
100        ..
101    } = parse_macro_input!(input);
102
103    match data {
104        syn::Data::Struct(data_struct) => derive_trace_struct(ident, data_struct, generics).into(),
105        syn::Data::Enum(data_enum) => derive_trace_enum(ident, data_enum).into(),
106        _ => panic!("Union types are not supported."),
107    }
108}
109
110fn derive_trace_struct(
111    name: Ident,
112    record: DataStruct,
113    generics: Generics,
114) -> proc_macro2::TokenStream {
115    let fields = match record.fields {
116        Fields::Named(fields) => fields.named,
117        Fields::Unnamed(fields) => fields.unnamed,
118        _ => {
119            return quote! {
120                unsafe impl ::scheme_rs::gc::Trace for #name {
121                    unsafe fn visit_children(&self, visitor: fn(::scheme_rs::gc::OpaqueGcPtr)) {}
122                }
123            }
124        }
125    };
126
127    let Generics {
128        mut params,
129        where_clause,
130        ..
131    } = generics;
132
133    let mut unbound_params = Punctuated::<GenericParam, Token![,]>::new();
134
135    for param in params.iter_mut() {
136        match param {
137            GenericParam::Type(ref mut ty) => {
138                ty.bounds.push(syn::TypeParamBound::Verbatim(
139                    quote! { ::scheme_rs::gc::Trace },
140                ));
141                unbound_params.push(GenericParam::Type(syn::TypeParam::from(ty.ident.clone())));
142            }
143            param => unbound_params.push(param.clone()),
144        }
145    }
146
147    let field_visits = fields
148        .iter()
149        .enumerate()
150        .map(|(i, f)| {
151            let ident = f.ident.clone().map_or_else(
152                || {
153                    Member::Unnamed(syn::Index {
154                        index: i as u32,
155                        span: Span::call_site(),
156                    })
157                },
158                Member::Named,
159            );
160            if is_gc(&f.ty) {
161                quote! {
162                    visitor(self.#ident.as_opaque());
163                }
164            } else {
165                quote! {
166                    self. #ident .visit_children(visitor);
167                }
168            }
169        })
170        .collect::<Vec<_>>();
171
172    let field_drops = fields
173        .iter()
174        .enumerate()
175        .flat_map(|(i, f)| {
176            let ident = f.ident.clone().map_or_else(
177                || {
178                    Member::Unnamed(syn::Index {
179                        index: i as u32,
180                        span: Span::call_site(),
181                    })
182                },
183                Member::Named,
184            );
185            if !is_gc(&f.ty) {
186                Some(quote! {
187                        self.#ident.finalize();
188                })
189            } else {
190                None
191            }
192        })
193        .collect::<Vec<_>>();
194
195    quote! {
196        #[automatically_derived]
197        unsafe impl<#params> ::scheme_rs::gc::Trace for #name <#unbound_params>
198        #where_clause
199        {
200            unsafe fn visit_children(&self, visitor: fn(::scheme_rs::gc::OpaqueGcPtr)) {
201                #(
202                    #field_visits
203                )*
204            }
205
206            unsafe fn finalize(&mut self) {
207                #(
208                    #field_drops
209                )*
210            }
211        }
212    }
213}
214
215// TODO: Add generics here
216fn derive_trace_enum(name: Ident, data_enum: DataEnum) -> proc_macro2::TokenStream {
217    let (visit_match_clauses, finalize_match_clauses): (Vec<_>, Vec<_>) = data_enum
218        .variants
219        .into_iter()
220        .flat_map(|variant| {
221            let fields: Vec<_> = match variant.fields {
222                Fields::Named(ref named) => named
223                    .named
224                    .iter()
225                    .map(|field| (field.ty.clone(), field.ident.as_ref().unwrap().clone()))
226                    .collect(),
227                Fields::Unnamed(ref unnamed) => unnamed
228                    .unnamed
229                    .iter()
230                    .enumerate()
231                    .map(|(i, field)| {
232                        let ident = Ident::new(&format!("t{i}"), Span::call_site());
233                        (field.ty.clone(), ident)
234                    })
235                    .collect(),
236                _ => return None,
237            };
238            let visits: Vec<_> = fields
239                .iter()
240                .map(|(ty, accessor)| {
241                    if is_gc(ty) {
242                        quote! {
243                            visitor(#accessor.as_opaque())
244                        }
245                    } else {
246                        quote! {
247                            #accessor.visit_children(visitor)
248                        }
249                    }
250                })
251                .collect();
252            let drops: Vec<_> = fields
253                .iter()
254                .filter(|(ty, _)| !is_gc(ty))
255                .map(|(_, accessor)| {
256                    quote! {
257                        #accessor.finalize();
258                    }
259                })
260                .collect();
261            let field_name = fields.iter().map(|(_, field)| field);
262            let fields_destructured = match variant.fields {
263                Fields::Named(..) => quote! { { #( ref #field_name, )* .. } },
264                _ => quote! { ( #( ref #field_name ),* ) },
265            };
266            let field_name = fields.iter().map(|(_, field)| field);
267            let fields_destructured_mut = match variant.fields {
268                Fields::Named(..) => quote! { { #( ref mut #field_name, )* .. } },
269                _ => quote! { ( #( ref mut #field_name ),* ) },
270            };
271            let variant_name = variant.ident;
272            Some((
273                quote! {
274                    Self::#variant_name #fields_destructured => {
275                        #(
276                            #visits;
277                        )*
278                    }
279                },
280                quote! {
281                    Self::#variant_name #fields_destructured_mut => {
282                        #(
283                            #drops
284                        )*
285                    }
286                },
287            ))
288        })
289        .unzip();
290    quote! {
291        unsafe impl ::scheme_rs::gc::Trace for #name {
292            unsafe fn visit_children(&self, visitor: fn(::scheme_rs::gc::OpaqueGcPtr)) {
293                match self {
294                    #( #visit_match_clauses )*,
295                    _ => (),
296                }
297            }
298
299            unsafe fn finalize(&mut self) {
300                match self {
301                    #( #finalize_match_clauses )*,
302                    _ => (),
303                }
304            }
305        }
306    }
307}
308
309fn is_gc(arg: &Type) -> bool {
310    if let Type::Path(ref path) = arg {
311        return path
312            .path
313            .segments
314            .last()
315            .map(|p| p.ident.to_string())
316            .as_deref()
317            == Some("Gc");
318    }
319    false
320}