Skip to main content

vortex_array_macros/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Proc macros for `vortex-array`.
5
6use proc_macro::TokenStream;
7use quote::format_ident;
8use quote::quote;
9use syn::Field;
10use syn::Fields;
11use syn::Ident;
12use syn::ItemStruct;
13use syn::Path;
14use syn::Type;
15use syn::Visibility;
16use syn::parse_macro_input;
17use syn::spanned::Spanned;
18
19/// Generate slot index constants, a borrowed view struct, and a typed ext trait
20/// from a slot struct definition.
21///
22/// Fields must be `ArrayRef` (required slot) or `Option<ArrayRef>` (optional slot).
23/// Field declaration order determines slot indices.
24///
25/// # Example
26///
27/// ```ignore
28/// #[array_slots(Patched)]
29/// pub struct PatchedSlots {
30///     pub inner: ArrayRef,
31///     pub lane_offsets: ArrayRef,
32///     pub patch_indices: ArrayRef,
33///     pub patch_values: ArrayRef,
34/// }
35/// ```
36///
37/// # Generated output
38///
39/// Given the above, the macro generates:
40///
41/// ```ignore
42/// // --- The original struct is preserved as-is ---
43/// pub struct PatchedSlots { ... }
44///
45/// // --- Slot index constants and conversion methods on the struct ---
46/// impl PatchedSlots {
47///     pub const INNER: usize = 0;
48///     pub const LANE_OFFSETS: usize = 1;
49///     pub const PATCH_INDICES: usize = 2;
50///     pub const PATCH_VALUES: usize = 3;
51///     pub const COUNT: usize = 4;
52///     pub const NAMES: [&'static str; 4] = ["inner", "lane_offsets", "patch_indices", "patch_values"];
53///
54///     /// Take ownership of slots from a `Vec<Option<ArrayRef>>`.
55///     pub fn from_slots(slots: Vec<Option<ArrayRef>>) -> Self { ... }
56///
57///     /// Convert back into storage order.
58///     pub fn into_slots(self) -> Vec<Option<ArrayRef>> { ... }
59/// }
60///
61/// // --- Borrowed view with &ArrayRef / Option<&ArrayRef> fields ---
62/// pub struct PatchedSlotsView<'a> {
63///     pub inner: &'a ArrayRef,
64///     pub lane_offsets: &'a ArrayRef,
65///     pub patch_indices: &'a ArrayRef,
66///     pub patch_values: &'a ArrayRef,
67/// }
68///
69/// impl<'a> PatchedSlotsView<'a> {
70///     pub fn from_slots(slots: &'a [Option<ArrayRef>]) -> Self { ... }
71///     pub fn to_owned(&self) -> PatchedSlots { ... }
72/// }
73///
74/// // --- Ext trait with per-field accessors + slots_view() ---
75/// pub trait PatchedArraySlotsExt: TypedArrayRef<Patched> {
76///     fn inner(&self) -> &ArrayRef { ... }         // indexes slots directly
77///     fn lane_offsets(&self) -> &ArrayRef { ... }
78///     fn patch_indices(&self) -> &ArrayRef { ... }
79///     fn patch_values(&self) -> &ArrayRef { ... }
80///     fn slots_view(&self) -> PatchedSlotsView<'_> { ... }
81/// }
82///
83/// impl<T: TypedArrayRef<Patched>> PatchedArraySlotsExt for T {}
84/// ```
85///
86/// # Required vs optional slots
87///
88/// - `ArrayRef` — the slot must be present. `from_slots()` panics if `None`.
89///   The ext trait accessor returns `&ArrayRef`. The view field is `&'a ArrayRef`.
90///
91/// - `Option<ArrayRef>` — the slot may be absent. `from_slots()` preserves `None`.
92///   The ext trait accessor returns `Option<&ArrayRef>`. The view field is
93///   `Option<&'a ArrayRef>`.
94///
95/// The underlying storage is always `Vec<Option<ArrayRef>>` — the field type only
96/// controls whether the macro inserts a `.vortex_expect()` unwrap or not.
97#[proc_macro_attribute]
98pub fn array_slots(attr: TokenStream, item: TokenStream) -> TokenStream {
99    let encoding = parse_macro_input!(attr as Path);
100    let item_struct = parse_macro_input!(item as ItemStruct);
101
102    match expand_array_slots(encoding, item_struct) {
103        Ok(tokens) => tokens.into(),
104        Err(err) => err.to_compile_error().into(),
105    }
106}
107
108fn expand_array_slots(
109    encoding: Path,
110    item_struct: ItemStruct,
111) -> syn::Result<proc_macro2::TokenStream> {
112    if !item_struct.generics.params.is_empty() || item_struct.generics.where_clause.is_some() {
113        return Err(syn::Error::new(
114            item_struct.generics.span(),
115            "#[array_slots] does not support generic slot structs",
116        ));
117    }
118
119    let fields = match &item_struct.fields {
120        Fields::Named(fields) => &fields.named,
121        _ => {
122            return Err(syn::Error::new(
123                item_struct.span(),
124                "#[array_slots] requires a struct with named fields",
125            ));
126        }
127    };
128
129    let encoding_ident = encoding
130        .segments
131        .last()
132        .map(|segment| &segment.ident)
133        .ok_or_else(|| syn::Error::new(encoding.span(), "missing encoding type"))?;
134
135    let struct_ident = &item_struct.ident;
136    let struct_vis = &item_struct.vis;
137    let view_ident = format_ident!("{}View", ident_name(struct_ident));
138    let ext_ident = format_ident!("{}ArraySlotsExt", ident_name(encoding_ident));
139
140    let field_specs = fields
141        .iter()
142        .enumerate()
143        .map(|(index, field)| SlotField::new(field, index, struct_ident))
144        .collect::<syn::Result<Vec<_>>>()?;
145
146    let idx_consts = field_specs.iter().map(SlotField::idx_const);
147    let view_fields = field_specs.iter().map(SlotField::view_field);
148    let view_from_slots = field_specs.iter().map(SlotField::view_from_slots);
149    let view_to_owned = field_specs.iter().map(SlotField::view_to_owned);
150    let owned_from_slots = field_specs.iter().map(SlotField::owned_from_slots);
151    let into_slots = field_specs.iter().map(SlotField::storage_slot);
152    let ext_methods = field_specs.iter().map(SlotField::ext_method);
153    let slot_names = field_specs.iter().map(|field| field.slot_name.as_str());
154    let slot_count = field_specs.len();
155
156    Ok(quote! {
157        #item_struct
158
159        impl #struct_ident {
160            #(#idx_consts)*
161
162            #[doc = "Total number of slots."]
163            pub const COUNT: usize = #slot_count;
164
165            #[doc = "Slot names in storage order."]
166            pub const NAMES: [&'static str; #slot_count] = [#(#slot_names),*];
167
168            #[doc = "Convert owned slot storage into an owned slot struct."]
169            pub fn from_slots(mut slots: Vec<Option<::vortex_array::ArrayRef>>) -> Self {
170                Self {
171                    #(#owned_from_slots,)*
172                }
173            }
174
175            #[doc = "Convert this slot struct into storage order."]
176            pub fn into_slots(self) -> Vec<Option<::vortex_array::ArrayRef>> {
177                vec![#(#into_slots),*]
178            }
179        }
180
181        #[derive(Clone, Copy, Debug)]
182        #[doc = concat!("Borrowed view of `", stringify!(#struct_ident), "`.")]
183        #struct_vis struct #view_ident<'a> {
184            #(#view_fields,)*
185        }
186
187        impl<'a> #view_ident<'a> {
188            #[doc = "Borrow a slot slice as a typed view."]
189            pub fn from_slots(slots: &'a [Option<::vortex_array::ArrayRef>]) -> Self {
190                Self {
191                    #(#view_from_slots,)*
192                }
193            }
194
195            #[doc = "Clone all referenced slots into an owned slot struct."]
196            pub fn to_owned(&self) -> #struct_ident {
197                #struct_ident {
198                    #(#view_to_owned,)*
199                }
200            }
201        }
202
203        #[doc = concat!("Typed array accessors for `", stringify!(#encoding_ident), "`.")]
204        #struct_vis trait #ext_ident: ::vortex_array::TypedArrayRef<#encoding> {
205            #(#ext_methods)*
206
207            #[doc = "Returns a borrowed view of all slots."]
208            fn slots_view(&self) -> #view_ident<'_> {
209                #view_ident::from_slots(self.as_ref().slots())
210            }
211        }
212
213        impl<T: ::vortex_array::TypedArrayRef<#encoding>> #ext_ident for T {}
214    })
215}
216
217struct SlotField {
218    field_ident: Ident,
219    field_vis: Visibility,
220    const_ident: Ident,
221    slot_name: String,
222    slot_type: SlotFieldType,
223    index: usize,
224    expect_message: syn::LitStr,
225    struct_ident: Ident,
226}
227
228impl SlotField {
229    fn new(field: &Field, index: usize, struct_ident: &Ident) -> syn::Result<Self> {
230        let field_ident = field
231            .ident
232            .clone()
233            .ok_or_else(|| syn::Error::new(field.span(), "slot fields must be named"))?;
234        let field_name = ident_name(&field_ident);
235        let const_ident = format_ident!("{}", to_screaming_snake_case(&field_name));
236        let slot_type = SlotFieldType::from_syn_type(&field.ty)?;
237        let expect_message = syn::LitStr::new(
238            &format!("{} {} slot", ident_name(struct_ident), field_name),
239            field.span(),
240        );
241
242        Ok(Self {
243            field_ident,
244            field_vis: field.vis.clone(),
245            const_ident,
246            slot_name: field_name,
247            slot_type,
248            index,
249            expect_message,
250            struct_ident: struct_ident.clone(),
251        })
252    }
253
254    fn idx_const(&self) -> proc_macro2::TokenStream {
255        let const_ident = &self.const_ident;
256        let index = self.index;
257        let slot_name = &self.slot_name;
258
259        quote! {
260            #[doc = concat!("Slot index for `", #slot_name, "`.")]
261            pub const #const_ident: usize = #index;
262        }
263    }
264
265    fn view_field(&self) -> proc_macro2::TokenStream {
266        let field_ident = &self.field_ident;
267        let field_vis = &self.field_vis;
268        let ty = self.slot_type.view_field_ty();
269
270        quote! {
271            #field_vis #field_ident: #ty
272        }
273    }
274
275    fn view_from_slots(&self) -> proc_macro2::TokenStream {
276        let field_ident = &self.field_ident;
277        let struct_ident = &self.struct_ident;
278        let const_ident = &self.const_ident;
279        let expect_message = &self.expect_message;
280
281        match self.slot_type {
282            SlotFieldType::Required => quote! {
283                #field_ident: ::vortex_error::VortexExpect::vortex_expect(
284                    slots[#struct_ident::#const_ident].as_ref(),
285                    #expect_message,
286                )
287            },
288            SlotFieldType::Optional => quote! {
289                #field_ident: slots[#struct_ident::#const_ident].as_ref()
290            },
291        }
292    }
293
294    fn view_to_owned(&self) -> proc_macro2::TokenStream {
295        let field_ident = &self.field_ident;
296
297        match self.slot_type {
298            SlotFieldType::Required => quote! {
299                #field_ident: ::std::clone::Clone::clone(self.#field_ident)
300            },
301            SlotFieldType::Optional => quote! {
302                #field_ident: self.#field_ident.cloned()
303            },
304        }
305    }
306
307    fn owned_from_slots(&self) -> proc_macro2::TokenStream {
308        let field_ident = &self.field_ident;
309        let struct_ident = &self.struct_ident;
310        let const_ident = &self.const_ident;
311        let expect_message = &self.expect_message;
312
313        match self.slot_type {
314            SlotFieldType::Required => quote! {
315                #field_ident: ::vortex_error::VortexExpect::vortex_expect(
316                    slots[#struct_ident::#const_ident].take(),
317                    #expect_message,
318                )
319            },
320            SlotFieldType::Optional => quote! {
321                #field_ident: slots[#struct_ident::#const_ident].take()
322            },
323        }
324    }
325
326    fn storage_slot(&self) -> proc_macro2::TokenStream {
327        let field_ident = &self.field_ident;
328
329        match self.slot_type {
330            SlotFieldType::Required => quote! {
331                Some(self.#field_ident)
332            },
333            SlotFieldType::Optional => quote! {
334                self.#field_ident
335            },
336        }
337    }
338
339    fn ext_method(&self) -> proc_macro2::TokenStream {
340        let field_ident = &self.field_ident;
341        let struct_ident = &self.struct_ident;
342        let const_ident = &self.const_ident;
343        let expect_message = &self.expect_message;
344
345        match self.slot_type {
346            SlotFieldType::Required => quote! {
347                #[inline]
348                fn #field_ident(&self) -> &::vortex_array::ArrayRef {
349                    ::vortex_error::VortexExpect::vortex_expect(
350                        self.as_ref().slots()[#struct_ident::#const_ident].as_ref(),
351                        #expect_message,
352                    )
353                }
354            },
355            SlotFieldType::Optional => quote! {
356                #[inline]
357                fn #field_ident(&self) -> Option<&::vortex_array::ArrayRef> {
358                    self.as_ref().slots()[#struct_ident::#const_ident].as_ref()
359                }
360            },
361        }
362    }
363}
364
/// Whether a slot field must always be populated or may be absent.
#[derive(Copy, Clone)]
enum SlotFieldType {
    /// Declared as `ArrayRef`; generated code unwraps it with `vortex_expect`.
    Required,
    /// Declared as `Option<ArrayRef>`; generated code passes `None` through.
    Optional,
}
370
371impl SlotFieldType {
372    fn from_syn_type(ty: &Type) -> syn::Result<Self> {
373        if is_array_ref_type(ty) {
374            return Ok(Self::Required);
375        }
376
377        if let Some(inner_ty) = option_inner_type(ty)
378            && is_array_ref_type(inner_ty)
379        {
380            return Ok(Self::Optional);
381        }
382
383        Err(syn::Error::new(
384            ty.span(),
385            "#[array_slots] fields must be ArrayRef or Option<ArrayRef>",
386        ))
387    }
388
389    fn view_field_ty(self) -> proc_macro2::TokenStream {
390        match self {
391            Self::Required => quote! { &'a ::vortex_array::ArrayRef },
392            Self::Optional => quote! { Option<&'a ::vortex_array::ArrayRef> },
393        }
394    }
395}
396
397fn is_array_ref_type(ty: &Type) -> bool {
398    matches!(
399        ty,
400        Type::Path(type_path)
401            if type_path.qself.is_none()
402                && type_path
403                    .path
404                    .segments
405                    .last()
406                    .is_some_and(|segment| segment.ident == "ArrayRef")
407    )
408}
409
410fn option_inner_type(ty: &Type) -> Option<&Type> {
411    let Type::Path(type_path) = ty else {
412        return None;
413    };
414    let segment = type_path.path.segments.last()?;
415    if segment.ident != "Option" {
416        return None;
417    }
418
419    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
420        return None;
421    };
422
423    match args.args.first()? {
424        syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
425        _ => None,
426    }
427}
428
429fn ident_name(ident: &Ident) -> String {
430    ident.to_string().trim_start_matches("r#").to_owned()
431}
432
/// Convert a field name to SCREAMING_SNAKE_CASE for use as a constant name.
///
/// ASCII letters are upper-cased. An `_` is inserted before an ASCII uppercase
/// letter that directly follows a lowercase ASCII letter or digit
/// (`laneOffsets` -> `LANE_OFFSETS`). Existing underscores and runs of capitals
/// are kept as-is (`HTTPServer` -> `HTTPSERVER`).
fn to_screaming_snake_case(name: &str) -> String {
    let mut screaming = String::with_capacity(name.len());
    let mut previous: Option<char> = None;

    for current in name.chars() {
        let after_lower_or_digit =
            previous.is_some_and(|p| p.is_ascii_lowercase() || p.is_ascii_digit());
        if after_lower_or_digit && current.is_ascii_uppercase() {
            screaming.push('_');
        }
        screaming.push(current.to_ascii_uppercase());
        previous = Some(current);
    }

    screaming
}