struct_split_macro/
lib.rs

use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2 as pm;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, Ident, Path};
7
8
9// =============
10// === Utils ===
11// =============
12
13/// Get the current crate name;
14fn crate_name() -> Ident {
15    let macro_lib = env!("CARGO_PKG_NAME");
16    let suffix = "-macro";
17    if !macro_lib.ends_with(suffix) { panic!("Internal error.") }
18    let crate_name = &macro_lib[..macro_lib.len() - suffix.len()].replace('-',"_");
19    Ident::new(crate_name, Span::call_site())
20}
21
22/// Extract the module macro attribute.
23fn extract_module_attr(input: &DeriveInput) -> Path {
24    let mut module: Option<Path> = None;
25    for attr in &input.attrs {
26        if attr.path().is_ident("module") {
27            let tokens = attr.meta.require_list().unwrap().tokens.clone();
28            if let Ok(path) = syn::parse2::<Path>(tokens) {
29                module = Some(path);
30            }
31        }
32    }
33    module.expect("The 'module' attribute is required.")
34}
35
36
37// =============
38// === Macro ===
39// =============
40
41/// Derive impl. Comments in the code show expansion of the following example struct:
42/// ```ignore
43/// pub struct Ctx {
44///     geometry: GeometryCtx,
45///     material: MaterialCtx,
46///     mesh: MeshCtx,
47///     scene: SceneCtx,
48/// }
49/// ```
50#[proc_macro_derive(Split, attributes(module))]
51pub fn split_derive(input: TokenStream) -> TokenStream {
52    let lib = crate_name();
53    let input = parse_macro_input!(input as DeriveInput);
54    let module = extract_module_attr(&input);
55
56    let struct_ident = input.ident;
57    let ref_struct_ident = Ident::new(&format!("{struct_ident}Ref"), struct_ident.span());
58
59    let fields = if let Data::Struct(data) = &input.data {
60        if let Fields::Named(fields) = &data.fields {
61            fields.named.iter().collect::<Vec<_>>()
62        } else {
63            Vec::new()
64        }
65    } else {
66        Vec::new()
67    };
68
69    let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect_vec();
70    let field_types = fields.iter().map(|f| &f.ty).collect_vec();
71    let params = field_idents.iter().map(|i| Ident::new(&i.to_string(), i.span())).collect_vec();
72    let bounds_params_access = quote! { #(#params: #lib::Access,)* };
73    let field_values = field_types.iter().zip(params.iter()).map(|(field_ty, param)| {
74        quote! { #lib::Value<'_t, #param, #field_ty> }
75    }).collect_vec();
76
77    // Generates:
78    // #[repr(C)]
79    // pub struct CtxRef<'t, geometry, material, mesh, scene> {
80    //     geometry: Value<'t, geometry, GeometryCtx>,
81    //     material: Value<'t, material, MaterialCtx>,
82    //     mesh: Value<'t, mesh, MeshCtx>,
83    //     scene: Value<'t, scene, SceneCtx>,
84    // }
85    let ref_struct = quote! {
86        #[derive(Debug)]
87        #[repr(C)]
88        #[allow(non_camel_case_types)]
89        pub struct #ref_struct_ident<'_t, #(#params),*>
90        where #bounds_params_access {
91            #(pub #field_idents : #field_values),*
92        }
93    };
94
95    // Generates:
96    // impl<'t, geometry, material, mesh, scene>
97    //     AsRefs<'t, CtxRef<'t, geometry, material, mesh, scene>> for Ctx
98    // where
99    //     geometry:    Access,
100    //     material:    Access,
101    //     mesh:        Access,
102    //     scene:       Access,
103    //     GeometryCtx: RefCast<'t, Value<'t, geometry, GeometryCtx>>,
104    //     MaterialCtx: RefCast<'t, Value<'t, material, MaterialCtx>>,
105    //     MeshCtx:     RefCast<'t, Value<'t, mesh,     MeshCtx>>,
106    //     SceneCtx:    RefCast<'t, Value<'t, scene,    SceneCtx>>,
107    // {
108    //     fn as_refs_impl(&'t mut self) -> CtxRef<'t, geometry, material, mesh, scene> {
109    //         CtxRef {
110    //             geometry: RefCast::ref_cast(&mut self.geometry),
111    //             material: RefCast::ref_cast(&mut self.material),
112    //             mesh:     RefCast::ref_cast(&mut self.mesh),
113    //             scene:    RefCast::ref_cast(&mut self.scene),
114    //         }
115    //     }
116    // }
117    let impl_as_refs = quote! {
118        #[allow(non_camel_case_types)]
119        impl<'_t, #(#params,)*>
120        #lib::AsRefs<'_t, #ref_struct_ident<'_t, #(#params,)*>> for #struct_ident
121        where #bounds_params_access #(#field_types: #lib::RefCast<'_t, #field_values>,)* {
122            #[inline(always)]
123            fn as_refs_impl(& '_t mut self) -> #ref_struct_ident<'_t, #(#params,)*> {
124                #ref_struct_ident {
125                    #(#field_idents: #lib::RefCast::ref_cast(&mut self.#field_idents),)*
126                }
127            }
128        }
129    };
130
131    // Generates:
132    // impl Ctx {
133    //     pub fn as_ref_mut<'t>(&'t mut self) -> CtxRef<'t, RefMut, RefMut, RefMut, RefMut> {
134    //         CtxRef {
135    //             geometry: &mut self.geometry,
136    //             material: &mut self.material,
137    //             mesh:     &mut self.mesh,
138    //             scene:    &mut self.scene,
139    //         }
140    //     }
141    // }
142    let impl_as_ref_mut = {
143        let ref_muts = params.iter().map(|_| quote!{#lib::RefMut}).collect_vec();
144        quote! {
145            #[allow(non_camel_case_types)]
146            impl #struct_ident {
147                #[inline(always)]
148                pub fn as_ref_mut<'_t>(&'_t mut self) -> #ref_struct_ident<'_t, #(#ref_muts,)*> {
149                    #ref_struct_ident {
150                        #(#field_idents: &mut self.#field_idents,)*
151                    }
152                }
153            }
154        }
155    };
156
157    // Generates:
158    // impl<'t, geometry_target, material_target, mesh_target, scene_target,
159    //          geometry,        material,        mesh,        scene>
160    // Split<CtxRef<'t, geometry_target, material_target, mesh_target, scene_target>>
161    // for CtxRef<'t, geometry,        material,        mesh,        scene>
162    // where
163    //     geometry:        Access,
164    //     material:        Access,
165    //     mesh:            Access,
166    //     scene:           Access,
167    //     geometry_target: Access,
168    //     material_target: Access,
169    //     mesh_target:     Access,
170    //     scene_target:    Access,
171    //     geometry:        Acquire<geometry_target>,
172    //     material:        Acquire<material_target>,
173    //     mesh:            Acquire<mesh_target>,
174    //     scene:           Acquire<scene_target>,
175    // {
176    //     type Rest = CtxRef<'t,
177    //         Acquired<geometry, target_geometry>,
178    //         Acquired<material, target_material>,
179    //         Acquired<mesh,     target_mesh>,
180    //         Acquired<scene,    target_scene>,
181    //     >;
182    // }
183    let impl_split = {
184        let target_params = params.iter().map(|i| Ident::new(&format!("{i}_target"), i.span())).collect_vec();
185        let bounds_target_params_access = quote! { #(#target_params: #lib::Access,)* };
186        quote! {
187            #[allow(non_camel_case_types)]
188            impl<'_t, #(#params,)* #(#target_params,)*>
189            #lib::Split<#ref_struct_ident<'_t, #(#target_params,)*>> for #ref_struct_ident<'_t, #(#params,)*>
190            where
191                #bounds_params_access
192                #bounds_target_params_access
193                #(#params: #lib::Acquire<#target_params>,)*
194            {
195                type Rest = #ref_struct_ident<'_t, #(#lib::Acquired<#params, #target_params>,)*>;
196            }
197        }
198    };
199
200    // Generates:
201    // #[macro_export]
202    // macro_rules! _Ctx {
203    //     ($lt:lifetime $ ($ts:tt) *) => {
204    //         CtxImpl! { $lt [[None] [None] [None] [None]] [$($ts)*] }
205    //     };
206    //     ($($ts:tt)*) => {
207    //         CtxImpl! { '_ [[None] [None] [None] [None]] [, $ ($ts) *] }
208    //     };
209    // }
210    //
211    // #[macro_export]
212    // macro_rules! CtxImpl {
213    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [,* $($xs:tt)*]) => {
214    //         CtxImpl! { $lt [[Ref] [Ref] [Ref] [Ref]] [$ ($xs) *] }
215    //     };
216    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut * $ ($xs:tt) *]) => {
217    //         CtxImpl! { $lt [[RefMut] [RefMut] [RefMut] [RefMut]] [$ ($xs) *] }
218    //     };
219    //
220    //
221    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? geometry $ ($xs:tt) *]) => {
222    //         CtxImpl! { $lt [[Ref] $t1 $t2 $t3] [$ ($xs) *] }
223    //     };
224    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? material $ ($xs:tt) *]) => {
225    //         CtxImpl! { $lt [$t0 [Ref] $t2 $t3] [$ ($xs) *] }
226    //     };
227    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? mesh $ ($xs:tt) *]) => {
228    //         CtxImpl! { $lt [$t0 $t1 [Ref] $t3] [$ ($xs) *] }
229    //     };
230    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? scene $ ($xs:tt) *]) => {
231    //         CtxImpl! { $lt [$t0 $t1 $t2 [Ref]] [$ ($xs) *] }
232    //     };
233    //
234    //
235    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut geometry $ ($xs:tt) *]) => {
236    //         CtxImpl! { $lt [[RefMut] $t1 $t2 $t3] [$ ($xs) *] }
237    //     };
238    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut material $ ($xs:tt) *]) => {
239    //         CtxImpl! { $lt [$t0 [RefMut] $t2 $t3] [$ ($xs) *] }
240    //     };
241    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut mesh $ ($xs:tt) *]) => {
242    //         CtxImpl! { $lt [$t0 $t1 [RefMut] $t3] [$ ($xs) *] }
243    //     };
244    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut scene $ ($xs:tt) *]) => {
245    //         CtxImpl! { $lt [$t0 $t1 $t2 [RefMut]] [$ ($xs) *] }
246    //     };
247    //
248    //
249    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! geometry $ ($xs:tt) *]) => {
250    //         CtxImpl! { $lt [[None] $t1 $t2 $t3] [$ ($xs) *] }
251    //     };
252    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! material $ ($xs:tt) *]) => {
253    //         CtxImpl! { $lt [$t0 [None] $t2 $t3] [$ ($xs) *] }
254    //     };
255    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! mesh $ ($xs:tt) *]) => {
256    //         CtxImpl! { $lt [$t0 $t1 [None] $t3] [$ ($xs) *] }
257    //     };
258    //     ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! scene $ ($xs:tt) *]) => {
259    //         CtxImpl! { $lt [$t0 $t1 $t2 [None]] [$ ($xs) *] }
260    //     };
261    //
262    //
263    //     ($lt:lifetime [$ ([$ ($ts:tt) *]) *] [$ (,) *]) => {
264    //         CtxRef < $lt, $ ($ ($ts) *), * >
265    //     };
266    // }
267    // pub use _Ctx as Ctx;
268    let ref_macro = {
269        let q_none = quote! {[#lib::None]};
270        let q_ref = quote! {[#lib::Ref]};
271        let q_ref_mut = quote! {[#lib::RefMut]};
272        let all_none = field_idents.iter().map(|_| &q_none).collect_vec();
273        let all_ref = field_idents.iter().map(|_| &q_ref).collect_vec();
274        let all_ref_mut = field_idents.iter().map(|_| &q_ref_mut).collect_vec();
275        let ts_idents = field_idents.iter().enumerate().map(|(i, _)| Ident::new(&format!("t{i}"), Span::call_site())).collect_vec();
276        let ts = ts_idents.iter().map(|t| quote!($#t)).collect_vec();
277        let struct_ident2 = Ident::new(&format!("_{}", struct_ident), struct_ident.span());
278        let gen_patterns = |pattern: pm::TokenStream, access: &pm::TokenStream| {
279            field_idents.iter().enumerate().map(|(i, name)| {
280                let mut result = ts.iter().collect_vec();
281                result[i] = access;
282                quote! { (@ $lt:lifetime [#(#ts:tt)*] [, #pattern #name $($xs:tt)*]) => {
283                $crate::#struct_ident! {@ $lt [#(#result)*] [$($xs)*]} };
284            }
285            }).collect_vec()
286        };
287        let patterns_ref = gen_patterns(quote!{$(ref)?}, &q_ref);
288        let patterns_ref_mut = gen_patterns(quote!{mut}, &q_ref_mut);
289        let patterns_ref_none = gen_patterns(quote!{!}, &q_none);
290        quote! {
291            #[macro_export]
292            macro_rules! #struct_ident2 {
293                (@ $lt:lifetime [#(#ts:tt)*] [, ! * $($xs:tt)*]) => {
294                    $crate::#struct_ident! {@ $lt [#(#all_none)*] [$($xs)*]}
295                };
296                (@ $lt:lifetime [#(#ts:tt)*] [, * $($xs:tt)*]) => {
297                    $crate::#struct_ident! {@ $lt [#(#all_ref)*] [$($xs)*]}
298                };
299                (@ $lt:lifetime [#(#ts:tt)*] [, mut * $($xs:tt)*]) => {
300                    $crate::#struct_ident! {@ $lt [#(#all_ref_mut)*] [$($xs)*]}
301                };
302                #(#patterns_ref)*
303                #(#patterns_ref_mut)*
304                #(#patterns_ref_none)*
305                (@ $lt:lifetime [$([$($ts:tt)*])*] [$(,)*]) => { #module::#ref_struct_ident<$lt, $($($ts)*),*> };
306                (@ $($ts:tt)*) => { error };
307
308                ($lt:lifetime $($ts:tt)*) => {
309                    $crate::#struct_ident! {@ $lt [#(#all_none)*] [$($ts)*]}
310                };
311                ($($ts:tt)*) => {
312                    $crate::#struct_ident! {@ '_ [#(#all_none)*] [,$($ts)*]}
313                };
314            }
315
316            pub use #struct_ident2 as #struct_ident;
317        }
318    };
319
320    // Generates:
321    // impl<'t, geometry, material, mesh, scene>
322    // CtxRef<'t, geometry, material, mesh, scene>
323    // where geometry: Access, material: Access, mesh: Access, scene: Access {
324    //     pub fn extract_geometry(&mut self)
325    //         -> (&mut GeometryCtx, &mut <Self as Split<Ctx!['t, mut geometry]>>::Rest)
326    //     where geometry: Acquire<RefMut> {
327    //         let (a, b) = <Self as Split<Ctx! ['t, mut geometry]>>::split_impl(self);
328    //         (a.geometry, b)
329    //     }
330    //     ...
331    // }
332    let impl_extract_fields = {
333        let fns = field_idents.iter().zip(field_types.iter()).map(|(field, ty)| {
334            let name = Ident::new(&format!("extract_{field}"), field.span());
335            quote! {
336                #[inline(always)]
337                pub fn #name(&mut self) -> (&mut #ty, &mut <Self as #lib::Split<#struct_ident!['_t, mut #field]>>::Rest)
338                where #field: #lib::Acquire<#lib::RefMut> {
339                    let (a, b) = <Self as #lib::Split<#struct_ident!['_t, mut #field]>>::split_impl(self);
340                    (a.#field, b)
341                }
342            }
343        }).collect_vec();
344        quote! {
345            #[allow(non_camel_case_types)]
346            impl<'_t, #(#params,)*> #ref_struct_ident<'_t, #(#params,)*>
347            where #bounds_params_access {
348                #(#fns)*
349            }
350        }
351    };
352
353    let out = quote! {
354        #ref_struct
355        #impl_as_refs
356        #impl_as_ref_mut
357        #impl_split
358        #ref_macro
359        #impl_extract_fields
360    };
361
362    // println!(">>> {}", out);
363    TokenStream::from(out)
364}