vortex_array_macros/
lib.rs1use proc_macro::TokenStream;
7use quote::format_ident;
8use quote::quote;
9use syn::Field;
10use syn::Fields;
11use syn::Ident;
12use syn::ItemStruct;
13use syn::Path;
14use syn::Type;
15use syn::Visibility;
16use syn::parse_macro_input;
17use syn::spanned::Spanned;
18
19#[proc_macro_attribute]
98pub fn array_slots(attr: TokenStream, item: TokenStream) -> TokenStream {
99 let encoding = parse_macro_input!(attr as Path);
100 let item_struct = parse_macro_input!(item as ItemStruct);
101
102 match expand_array_slots(encoding, item_struct) {
103 Ok(tokens) => tokens.into(),
104 Err(err) => err.to_compile_error().into(),
105 }
106}
107
108fn expand_array_slots(
109 encoding: Path,
110 item_struct: ItemStruct,
111) -> syn::Result<proc_macro2::TokenStream> {
112 if !item_struct.generics.params.is_empty() || item_struct.generics.where_clause.is_some() {
113 return Err(syn::Error::new(
114 item_struct.generics.span(),
115 "#[array_slots] does not support generic slot structs",
116 ));
117 }
118
119 let fields = match &item_struct.fields {
120 Fields::Named(fields) => &fields.named,
121 _ => {
122 return Err(syn::Error::new(
123 item_struct.span(),
124 "#[array_slots] requires a struct with named fields",
125 ));
126 }
127 };
128
129 let encoding_ident = encoding
130 .segments
131 .last()
132 .map(|segment| &segment.ident)
133 .ok_or_else(|| syn::Error::new(encoding.span(), "missing encoding type"))?;
134
135 let struct_ident = &item_struct.ident;
136 let struct_vis = &item_struct.vis;
137 let view_ident = format_ident!("{}View", ident_name(struct_ident));
138 let ext_ident = format_ident!("{}ArraySlotsExt", ident_name(encoding_ident));
139
140 let field_specs = fields
141 .iter()
142 .enumerate()
143 .map(|(index, field)| SlotField::new(field, index, struct_ident))
144 .collect::<syn::Result<Vec<_>>>()?;
145
146 let idx_consts = field_specs.iter().map(SlotField::idx_const);
147 let view_fields = field_specs.iter().map(SlotField::view_field);
148 let view_from_slots = field_specs.iter().map(SlotField::view_from_slots);
149 let view_to_owned = field_specs.iter().map(SlotField::view_to_owned);
150 let owned_from_slots = field_specs.iter().map(SlotField::owned_from_slots);
151 let into_slots = field_specs.iter().map(SlotField::storage_slot);
152 let ext_methods = field_specs.iter().map(SlotField::ext_method);
153 let slot_names = field_specs.iter().map(|field| field.slot_name.as_str());
154 let slot_count = field_specs.len();
155
156 Ok(quote! {
157 #item_struct
158
159 impl #struct_ident {
160 #(#idx_consts)*
161
162 #[doc = "Total number of slots."]
163 pub const COUNT: usize = #slot_count;
164
165 #[doc = "Slot names in storage order."]
166 pub const NAMES: [&'static str; #slot_count] = [#(#slot_names),*];
167
168 #[doc = "Convert owned slot storage into an owned slot struct."]
169 pub fn from_slots(mut slots: Vec<Option<::vortex_array::ArrayRef>>) -> Self {
170 Self {
171 #(#owned_from_slots,)*
172 }
173 }
174
175 #[doc = "Convert this slot struct into storage order."]
176 pub fn into_slots(self) -> Vec<Option<::vortex_array::ArrayRef>> {
177 vec![#(#into_slots),*]
178 }
179 }
180
181 #[derive(Clone, Copy, Debug)]
182 #[doc = concat!("Borrowed view of `", stringify!(#struct_ident), "`.")]
183 #struct_vis struct #view_ident<'a> {
184 #(#view_fields,)*
185 }
186
187 impl<'a> #view_ident<'a> {
188 #[doc = "Borrow a slot slice as a typed view."]
189 pub fn from_slots(slots: &'a [Option<::vortex_array::ArrayRef>]) -> Self {
190 Self {
191 #(#view_from_slots,)*
192 }
193 }
194
195 #[doc = "Clone all referenced slots into an owned slot struct."]
196 pub fn to_owned(&self) -> #struct_ident {
197 #struct_ident {
198 #(#view_to_owned,)*
199 }
200 }
201 }
202
203 #[doc = concat!("Typed array accessors for `", stringify!(#encoding_ident), "`.")]
204 #struct_vis trait #ext_ident: ::vortex_array::TypedArrayRef<#encoding> {
205 #(#ext_methods)*
206
207 #[doc = "Returns a borrowed view of all slots."]
208 fn slots_view(&self) -> #view_ident<'_> {
209 #view_ident::from_slots(self.as_ref().slots())
210 }
211 }
212
213 impl<T: ::vortex_array::TypedArrayRef<#encoding>> #ext_ident for T {}
214 })
215}
216
217struct SlotField {
218 field_ident: Ident,
219 field_vis: Visibility,
220 const_ident: Ident,
221 slot_name: String,
222 slot_type: SlotFieldType,
223 index: usize,
224 expect_message: syn::LitStr,
225 struct_ident: Ident,
226}
227
228impl SlotField {
229 fn new(field: &Field, index: usize, struct_ident: &Ident) -> syn::Result<Self> {
230 let field_ident = field
231 .ident
232 .clone()
233 .ok_or_else(|| syn::Error::new(field.span(), "slot fields must be named"))?;
234 let field_name = ident_name(&field_ident);
235 let const_ident = format_ident!("{}", to_screaming_snake_case(&field_name));
236 let slot_type = SlotFieldType::from_syn_type(&field.ty)?;
237 let expect_message = syn::LitStr::new(
238 &format!("{} {} slot", ident_name(struct_ident), field_name),
239 field.span(),
240 );
241
242 Ok(Self {
243 field_ident,
244 field_vis: field.vis.clone(),
245 const_ident,
246 slot_name: field_name,
247 slot_type,
248 index,
249 expect_message,
250 struct_ident: struct_ident.clone(),
251 })
252 }
253
254 fn idx_const(&self) -> proc_macro2::TokenStream {
255 let const_ident = &self.const_ident;
256 let index = self.index;
257 let slot_name = &self.slot_name;
258
259 quote! {
260 #[doc = concat!("Slot index for `", #slot_name, "`.")]
261 pub const #const_ident: usize = #index;
262 }
263 }
264
265 fn view_field(&self) -> proc_macro2::TokenStream {
266 let field_ident = &self.field_ident;
267 let field_vis = &self.field_vis;
268 let ty = self.slot_type.view_field_ty();
269
270 quote! {
271 #field_vis #field_ident: #ty
272 }
273 }
274
275 fn view_from_slots(&self) -> proc_macro2::TokenStream {
276 let field_ident = &self.field_ident;
277 let struct_ident = &self.struct_ident;
278 let const_ident = &self.const_ident;
279 let expect_message = &self.expect_message;
280
281 match self.slot_type {
282 SlotFieldType::Required => quote! {
283 #field_ident: ::vortex_error::VortexExpect::vortex_expect(
284 slots[#struct_ident::#const_ident].as_ref(),
285 #expect_message,
286 )
287 },
288 SlotFieldType::Optional => quote! {
289 #field_ident: slots[#struct_ident::#const_ident].as_ref()
290 },
291 }
292 }
293
294 fn view_to_owned(&self) -> proc_macro2::TokenStream {
295 let field_ident = &self.field_ident;
296
297 match self.slot_type {
298 SlotFieldType::Required => quote! {
299 #field_ident: ::std::clone::Clone::clone(self.#field_ident)
300 },
301 SlotFieldType::Optional => quote! {
302 #field_ident: self.#field_ident.cloned()
303 },
304 }
305 }
306
307 fn owned_from_slots(&self) -> proc_macro2::TokenStream {
308 let field_ident = &self.field_ident;
309 let struct_ident = &self.struct_ident;
310 let const_ident = &self.const_ident;
311 let expect_message = &self.expect_message;
312
313 match self.slot_type {
314 SlotFieldType::Required => quote! {
315 #field_ident: ::vortex_error::VortexExpect::vortex_expect(
316 slots[#struct_ident::#const_ident].take(),
317 #expect_message,
318 )
319 },
320 SlotFieldType::Optional => quote! {
321 #field_ident: slots[#struct_ident::#const_ident].take()
322 },
323 }
324 }
325
326 fn storage_slot(&self) -> proc_macro2::TokenStream {
327 let field_ident = &self.field_ident;
328
329 match self.slot_type {
330 SlotFieldType::Required => quote! {
331 Some(self.#field_ident)
332 },
333 SlotFieldType::Optional => quote! {
334 self.#field_ident
335 },
336 }
337 }
338
339 fn ext_method(&self) -> proc_macro2::TokenStream {
340 let field_ident = &self.field_ident;
341 let struct_ident = &self.struct_ident;
342 let const_ident = &self.const_ident;
343 let expect_message = &self.expect_message;
344
345 match self.slot_type {
346 SlotFieldType::Required => quote! {
347 #[inline]
348 fn #field_ident(&self) -> &::vortex_array::ArrayRef {
349 ::vortex_error::VortexExpect::vortex_expect(
350 self.as_ref().slots()[#struct_ident::#const_ident].as_ref(),
351 #expect_message,
352 )
353 }
354 },
355 SlotFieldType::Optional => quote! {
356 #[inline]
357 fn #field_ident(&self) -> Option<&::vortex_array::ArrayRef> {
358 self.as_ref().slots()[#struct_ident::#const_ident].as_ref()
359 }
360 },
361 }
362 }
363}
364
/// Storage kind of a slot field, derived from its declared type.
#[derive(Copy, Clone)]
enum SlotFieldType {
    /// Declared as `ArrayRef`.
    Required,
    /// Declared as `Option<ArrayRef>`.
    Optional,
}
370
371impl SlotFieldType {
372 fn from_syn_type(ty: &Type) -> syn::Result<Self> {
373 if is_array_ref_type(ty) {
374 return Ok(Self::Required);
375 }
376
377 if let Some(inner_ty) = option_inner_type(ty)
378 && is_array_ref_type(inner_ty)
379 {
380 return Ok(Self::Optional);
381 }
382
383 Err(syn::Error::new(
384 ty.span(),
385 "#[array_slots] fields must be ArrayRef or Option<ArrayRef>",
386 ))
387 }
388
389 fn view_field_ty(self) -> proc_macro2::TokenStream {
390 match self {
391 Self::Required => quote! { &'a ::vortex_array::ArrayRef },
392 Self::Optional => quote! { Option<&'a ::vortex_array::ArrayRef> },
393 }
394 }
395}
396
397fn is_array_ref_type(ty: &Type) -> bool {
398 matches!(
399 ty,
400 Type::Path(type_path)
401 if type_path.qself.is_none()
402 && type_path
403 .path
404 .segments
405 .last()
406 .is_some_and(|segment| segment.ident == "ArrayRef")
407 )
408}
409
410fn option_inner_type(ty: &Type) -> Option<&Type> {
411 let Type::Path(type_path) = ty else {
412 return None;
413 };
414 let segment = type_path.path.segments.last()?;
415 if segment.ident != "Option" {
416 return None;
417 }
418
419 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
420 return None;
421 };
422
423 match args.args.first()? {
424 syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
425 _ => None,
426 }
427}
428
429fn ident_name(ident: &Ident) -> String {
430 ident.to_string().trim_start_matches("r#").to_owned()
431}
432
/// Converts `name` to `SCREAMING_SNAKE_CASE`.
///
/// Every ASCII letter is upper-cased. An underscore is inserted before an ASCII uppercase
/// letter that directly follows an ASCII lowercase letter or digit, so `fooBar` and `foo_bar`
/// both become `FOO_BAR`. Runs of capitals are not split (`HTTPServer` -> `HTTPSERVER`).
fn to_screaming_snake_case(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    let mut previous: Option<char> = None;

    for current in name.chars() {
        let starts_word = current.is_ascii_uppercase()
            && previous.is_some_and(|prev| prev.is_ascii_lowercase() || prev.is_ascii_digit());
        if starts_word {
            out.push('_');
        }
        out.push(current.to_ascii_uppercase());
        previous = Some(current);
    }

    out
}