struct_split_macro/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Ident, Data, Fields, Path};
4use itertools::Itertools;
5use proc_macro2::{Span};
6use proc_macro2 as pm;
7
8
9// =============
10// === Utils ===
11// =============
12
13/// Get the current crate name;
14fn crate_name() -> Ident {
15 let macro_lib = env!("CARGO_PKG_NAME");
16 let suffix = "-macro";
17 if !macro_lib.ends_with(suffix) { panic!("Internal error.") }
18 let crate_name = ¯o_lib[..macro_lib.len() - suffix.len()].replace('-',"_");
19 Ident::new(crate_name, Span::call_site())
20}
21
22/// Extract the module macro attribute.
23fn extract_module_attr(input: &DeriveInput) -> Path {
24 let mut module: Option<Path> = None;
25 for attr in &input.attrs {
26 if attr.path().is_ident("module") {
27 let tokens = attr.meta.require_list().unwrap().tokens.clone();
28 if let Ok(path) = syn::parse2::<Path>(tokens) {
29 module = Some(path);
30 }
31 }
32 }
33 module.expect("The 'module' attribute is required.")
34}
35
36
37// =============
38// === Macro ===
39// =============
40
41/// Derive impl. Comments in the code show expansion of the following example struct:
42/// ```ignore
43/// pub struct Ctx {
44/// geometry: GeometryCtx,
45/// material: MaterialCtx,
46/// mesh: MeshCtx,
47/// scene: SceneCtx,
48/// }
49/// ```
50#[proc_macro_derive(Split, attributes(module))]
51pub fn split_derive(input: TokenStream) -> TokenStream {
52 let lib = crate_name();
53 let input = parse_macro_input!(input as DeriveInput);
54 let module = extract_module_attr(&input);
55
56 let struct_ident = input.ident;
57 let ref_struct_ident = Ident::new(&format!("{struct_ident}Ref"), struct_ident.span());
58
59 let fields = if let Data::Struct(data) = &input.data {
60 if let Fields::Named(fields) = &data.fields {
61 fields.named.iter().collect::<Vec<_>>()
62 } else {
63 Vec::new()
64 }
65 } else {
66 Vec::new()
67 };
68
69 let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect_vec();
70 let field_types = fields.iter().map(|f| &f.ty).collect_vec();
71 let params = field_idents.iter().map(|i| Ident::new(&i.to_string(), i.span())).collect_vec();
72 let bounds_params_access = quote! { #(#params: #lib::Access,)* };
73 let field_values = field_types.iter().zip(params.iter()).map(|(field_ty, param)| {
74 quote! { #lib::Value<'_t, #param, #field_ty> }
75 }).collect_vec();
76
77 // Generates:
78 // #[repr(C)]
79 // pub struct CtxRef<'t, geometry, material, mesh, scene> {
80 // geometry: Value<'t, geometry, GeometryCtx>,
81 // material: Value<'t, material, MaterialCtx>,
82 // mesh: Value<'t, mesh, MeshCtx>,
83 // scene: Value<'t, scene, SceneCtx>,
84 // }
85 let ref_struct = quote! {
86 #[derive(Debug)]
87 #[repr(C)]
88 #[allow(non_camel_case_types)]
89 pub struct #ref_struct_ident<'_t, #(#params),*>
90 where #bounds_params_access {
91 #(pub #field_idents : #field_values),*
92 }
93 };
94
95 // Generates:
96 // impl<'t, geometry, material, mesh, scene>
97 // AsRefs<'t, CtxRef<'t, geometry, material, mesh, scene>> for Ctx
98 // where
99 // geometry: Access,
100 // material: Access,
101 // mesh: Access,
102 // scene: Access,
103 // GeometryCtx: RefCast<'t, Value<'t, geometry, GeometryCtx>>,
104 // MaterialCtx: RefCast<'t, Value<'t, material, MaterialCtx>>,
105 // MeshCtx: RefCast<'t, Value<'t, mesh, MeshCtx>>,
106 // SceneCtx: RefCast<'t, Value<'t, scene, SceneCtx>>,
107 // {
108 // fn as_refs_impl(&'t mut self) -> CtxRef<'t, geometry, material, mesh, scene> {
109 // CtxRef {
110 // geometry: RefCast::ref_cast(&mut self.geometry),
111 // material: RefCast::ref_cast(&mut self.material),
112 // mesh: RefCast::ref_cast(&mut self.mesh),
113 // scene: RefCast::ref_cast(&mut self.scene),
114 // }
115 // }
116 // }
117 let impl_as_refs = quote! {
118 #[allow(non_camel_case_types)]
119 impl<'_t, #(#params,)*>
120 #lib::AsRefs<'_t, #ref_struct_ident<'_t, #(#params,)*>> for #struct_ident
121 where #bounds_params_access #(#field_types: #lib::RefCast<'_t, #field_values>,)* {
122 #[inline(always)]
123 fn as_refs_impl(& '_t mut self) -> #ref_struct_ident<'_t, #(#params,)*> {
124 #ref_struct_ident {
125 #(#field_idents: #lib::RefCast::ref_cast(&mut self.#field_idents),)*
126 }
127 }
128 }
129 };
130
131 // Generates:
132 // impl Ctx {
133 // pub fn as_ref_mut<'t>(&'t mut self) -> CtxRef<'t, RefMut, RefMut, RefMut, RefMut> {
134 // CtxRef {
135 // geometry: &mut self.geometry,
136 // material: &mut self.material,
137 // mesh: &mut self.mesh,
138 // scene: &mut self.scene,
139 // }
140 // }
141 // }
142 let impl_as_ref_mut = {
143 let ref_muts = params.iter().map(|_| quote!{#lib::RefMut}).collect_vec();
144 quote! {
145 #[allow(non_camel_case_types)]
146 impl #struct_ident {
147 #[inline(always)]
148 pub fn as_ref_mut<'_t>(&'_t mut self) -> #ref_struct_ident<'_t, #(#ref_muts,)*> {
149 #ref_struct_ident {
150 #(#field_idents: &mut self.#field_idents,)*
151 }
152 }
153 }
154 }
155 };
156
157 // Generates:
158 // impl<'t, geometry_target, material_target, mesh_target, scene_target,
159 // geometry, material, mesh, scene>
160 // Split<CtxRef<'t, geometry_target, material_target, mesh_target, scene_target>>
161 // for CtxRef<'t, geometry, material, mesh, scene>
162 // where
163 // geometry: Access,
164 // material: Access,
165 // mesh: Access,
166 // scene: Access,
167 // geometry_target: Access,
168 // material_target: Access,
169 // mesh_target: Access,
170 // scene_target: Access,
171 // geometry: Acquire<geometry_target>,
172 // material: Acquire<material_target>,
173 // mesh: Acquire<mesh_target>,
174 // scene: Acquire<scene_target>,
175 // {
176 // type Rest = CtxRef<'t,
177 // Acquired<geometry, target_geometry>,
178 // Acquired<material, target_material>,
179 // Acquired<mesh, target_mesh>,
180 // Acquired<scene, target_scene>,
181 // >;
182 // }
183 let impl_split = {
184 let target_params = params.iter().map(|i| Ident::new(&format!("{i}_target"), i.span())).collect_vec();
185 let bounds_target_params_access = quote! { #(#target_params: #lib::Access,)* };
186 quote! {
187 #[allow(non_camel_case_types)]
188 impl<'_t, #(#params,)* #(#target_params,)*>
189 #lib::Split<#ref_struct_ident<'_t, #(#target_params,)*>> for #ref_struct_ident<'_t, #(#params,)*>
190 where
191 #bounds_params_access
192 #bounds_target_params_access
193 #(#params: #lib::Acquire<#target_params>,)*
194 {
195 type Rest = #ref_struct_ident<'_t, #(#lib::Acquired<#params, #target_params>,)*>;
196 }
197 }
198 };
199
200 // Generates:
201 // #[macro_export]
202 // macro_rules! _Ctx {
203 // ($lt:lifetime $ ($ts:tt) *) => {
204 // CtxImpl! { $lt [[None] [None] [None] [None]] [$($ts)*] }
205 // };
206 // ($($ts:tt)*) => {
207 // CtxImpl! { '_ [[None] [None] [None] [None]] [, $ ($ts) *] }
208 // };
209 // }
210 //
211 // #[macro_export]
212 // macro_rules! CtxImpl {
213 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [,* $($xs:tt)*]) => {
214 // CtxImpl! { $lt [[Ref] [Ref] [Ref] [Ref]] [$ ($xs) *] }
215 // };
216 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut * $ ($xs:tt) *]) => {
217 // CtxImpl! { $lt [[RefMut] [RefMut] [RefMut] [RefMut]] [$ ($xs) *] }
218 // };
219 //
220 //
221 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? geometry $ ($xs:tt) *]) => {
222 // CtxImpl! { $lt [[Ref] $t1 $t2 $t3] [$ ($xs) *] }
223 // };
224 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? material $ ($xs:tt) *]) => {
225 // CtxImpl! { $lt [$t0 [Ref] $t2 $t3] [$ ($xs) *] }
226 // };
227 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? mesh $ ($xs:tt) *]) => {
228 // CtxImpl! { $lt [$t0 $t1 [Ref] $t3] [$ ($xs) *] }
229 // };
230 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, $(ref)? scene $ ($xs:tt) *]) => {
231 // CtxImpl! { $lt [$t0 $t1 $t2 [Ref]] [$ ($xs) *] }
232 // };
233 //
234 //
235 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut geometry $ ($xs:tt) *]) => {
236 // CtxImpl! { $lt [[RefMut] $t1 $t2 $t3] [$ ($xs) *] }
237 // };
238 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut material $ ($xs:tt) *]) => {
239 // CtxImpl! { $lt [$t0 [RefMut] $t2 $t3] [$ ($xs) *] }
240 // };
241 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut mesh $ ($xs:tt) *]) => {
242 // CtxImpl! { $lt [$t0 $t1 [RefMut] $t3] [$ ($xs) *] }
243 // };
244 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, mut scene $ ($xs:tt) *]) => {
245 // CtxImpl! { $lt [$t0 $t1 $t2 [RefMut]] [$ ($xs) *] }
246 // };
247 //
248 //
249 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! geometry $ ($xs:tt) *]) => {
250 // CtxImpl! { $lt [[None] $t1 $t2 $t3] [$ ($xs) *] }
251 // };
252 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! material $ ($xs:tt) *]) => {
253 // CtxImpl! { $lt [$t0 [None] $t2 $t3] [$ ($xs) *] }
254 // };
255 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! mesh $ ($xs:tt) *]) => {
256 // CtxImpl! { $lt [$t0 $t1 [None] $t3] [$ ($xs) *] }
257 // };
258 // ($lt:lifetime [$t0:tt $t1:tt $t2:tt $t3:tt] [, ! scene $ ($xs:tt) *]) => {
259 // CtxImpl! { $lt [$t0 $t1 $t2 [None]] [$ ($xs) *] }
260 // };
261 //
262 //
263 // ($lt:lifetime [$ ([$ ($ts:tt) *]) *] [$ (,) *]) => {
264 // CtxRef < $lt, $ ($ ($ts) *), * >
265 // };
266 // }
267 // pub use _Ctx as Ctx;
268 let ref_macro = {
269 let q_none = quote! {[#lib::None]};
270 let q_ref = quote! {[#lib::Ref]};
271 let q_ref_mut = quote! {[#lib::RefMut]};
272 let all_none = field_idents.iter().map(|_| &q_none).collect_vec();
273 let all_ref = field_idents.iter().map(|_| &q_ref).collect_vec();
274 let all_ref_mut = field_idents.iter().map(|_| &q_ref_mut).collect_vec();
275 let ts_idents = field_idents.iter().enumerate().map(|(i, _)| Ident::new(&format!("t{i}"), Span::call_site())).collect_vec();
276 let ts = ts_idents.iter().map(|t| quote!($#t)).collect_vec();
277 let struct_ident2 = Ident::new(&format!("_{}", struct_ident), struct_ident.span());
278 let gen_patterns = |pattern: pm::TokenStream, access: &pm::TokenStream| {
279 field_idents.iter().enumerate().map(|(i, name)| {
280 let mut result = ts.iter().collect_vec();
281 result[i] = access;
282 quote! { (@ $lt:lifetime [#(#ts:tt)*] [, #pattern #name $($xs:tt)*]) => {
283 $crate::#struct_ident! {@ $lt [#(#result)*] [$($xs)*]} };
284 }
285 }).collect_vec()
286 };
287 let patterns_ref = gen_patterns(quote!{$(ref)?}, &q_ref);
288 let patterns_ref_mut = gen_patterns(quote!{mut}, &q_ref_mut);
289 let patterns_ref_none = gen_patterns(quote!{!}, &q_none);
290 quote! {
291 #[macro_export]
292 macro_rules! #struct_ident2 {
293 (@ $lt:lifetime [#(#ts:tt)*] [, ! * $($xs:tt)*]) => {
294 $crate::#struct_ident! {@ $lt [#(#all_none)*] [$($xs)*]}
295 };
296 (@ $lt:lifetime [#(#ts:tt)*] [, * $($xs:tt)*]) => {
297 $crate::#struct_ident! {@ $lt [#(#all_ref)*] [$($xs)*]}
298 };
299 (@ $lt:lifetime [#(#ts:tt)*] [, mut * $($xs:tt)*]) => {
300 $crate::#struct_ident! {@ $lt [#(#all_ref_mut)*] [$($xs)*]}
301 };
302 #(#patterns_ref)*
303 #(#patterns_ref_mut)*
304 #(#patterns_ref_none)*
305 (@ $lt:lifetime [$([$($ts:tt)*])*] [$(,)*]) => { #module::#ref_struct_ident<$lt, $($($ts)*),*> };
306 (@ $($ts:tt)*) => { error };
307
308 ($lt:lifetime $($ts:tt)*) => {
309 $crate::#struct_ident! {@ $lt [#(#all_none)*] [$($ts)*]}
310 };
311 ($($ts:tt)*) => {
312 $crate::#struct_ident! {@ '_ [#(#all_none)*] [,$($ts)*]}
313 };
314 }
315
316 pub use #struct_ident2 as #struct_ident;
317 }
318 };
319
320 // Generates:
321 // impl<'t, geometry, material, mesh, scene>
322 // CtxRef<'t, geometry, material, mesh, scene>
323 // where geometry: Access, material: Access, mesh: Access, scene: Access {
324 // pub fn extract_geometry(&mut self)
325 // -> (&mut GeometryCtx, &mut <Self as Split<Ctx!['t, mut geometry]>>::Rest)
326 // where geometry: Acquire<RefMut> {
327 // let (a, b) = <Self as Split<Ctx! ['t, mut geometry]>>::split_impl(self);
328 // (a.geometry, b)
329 // }
330 // ...
331 // }
332 let impl_extract_fields = {
333 let fns = field_idents.iter().zip(field_types.iter()).map(|(field, ty)| {
334 let name = Ident::new(&format!("extract_{field}"), field.span());
335 quote! {
336 #[inline(always)]
337 pub fn #name(&mut self) -> (&mut #ty, &mut <Self as #lib::Split<#struct_ident!['_t, mut #field]>>::Rest)
338 where #field: #lib::Acquire<#lib::RefMut> {
339 let (a, b) = <Self as #lib::Split<#struct_ident!['_t, mut #field]>>::split_impl(self);
340 (a.#field, b)
341 }
342 }
343 }).collect_vec();
344 quote! {
345 #[allow(non_camel_case_types)]
346 impl<'_t, #(#params,)*> #ref_struct_ident<'_t, #(#params,)*>
347 where #bounds_params_access {
348 #(#fns)*
349 }
350 }
351 };
352
353 let out = quote! {
354 #ref_struct
355 #impl_as_refs
356 #impl_as_ref_mut
357 #impl_split
358 #ref_macro
359 #impl_extract_fields
360 };
361
362 // println!(">>> {}", out);
363 TokenStream::from(out)
364}