1use std::any::Any;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27use arrow_array::Array as _;
28use arrow_array::ArrayRef as ArrowArrayRef;
29use arrow_array::RecordBatch;
30use arrow_array::make_array;
31use arrow_schema::DataType;
32use arrow_schema::Field;
33use arrow_schema::Fields;
34use arrow_schema::Schema;
35use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY;
36use arrow_schema::extension::ExtensionType;
37use tracing::debug;
38use tracing::trace;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_session::Ref;
43use vortex_session::SessionExt;
44use vortex_session::SessionVar;
45use vortex_session::registry::Id;
46
47use crate::ArrayRef;
48use crate::ExecutionCtx;
49use crate::IntoArray;
50use crate::arc_swap_map::ArcSwapMap;
51use crate::arrays::StructArray;
52use crate::arrow::FromArrowArray;
53use crate::arrow::convert::nulls;
54use crate::arrow::convert::remove_nulls;
55use crate::arrow::executor::execute_arrow_naive;
56use crate::dtype::DType;
57use crate::dtype::FieldName;
58use crate::dtype::FieldNames;
59use crate::dtype::Nullability;
60use crate::dtype::StructFields;
61use crate::dtype::arrow::FromArrowType;
62use crate::dtype::arrow::to_data_type_naive;
63use crate::dtype::extension::ExtId;
64use crate::extension::datetime::AnyTemporal;
65use crate::extension::uuid::Uuid;
66use crate::validity::Validity;
67
68pub enum ArrowExport {
74 Unsupported(ArrayRef),
76 Exported(ArrowArrayRef),
78}
79
80pub enum ArrowImport {
86 Unsupported(ArrowArrayRef),
88 Imported(ArrayRef),
90}
91
92pub trait ArrowExportVTable: 'static + Send + Sync + Debug {
97 fn arrow_ext_id(&self) -> Id;
99
100 fn vortex_id(&self) -> Id;
104
105 fn to_arrow_field(
108 &self,
109 name: &str,
110 dtype: &DType,
111 session: &ArrowSession,
112 ) -> VortexResult<Option<Field>>;
113
114 fn execute_arrow(
119 &self,
120 array: ArrayRef,
121 target: &Field,
122 ctx: &mut ExecutionCtx,
123 ) -> VortexResult<ArrowExport>;
124}
125
126pub trait ArrowImportVTable: 'static + Send + Sync + Debug {
133 fn arrow_ext_id(&self) -> Id;
135
136 #[allow(clippy::wrong_self_convention)]
139 fn from_arrow_field(&self, field: &Field) -> VortexResult<Option<DType>>;
140
141 #[allow(clippy::wrong_self_convention)]
146 fn from_arrow_array(
147 &self,
148 array: ArrowArrayRef,
149 field: &Field,
150 dtype: &DType,
151 ) -> VortexResult<ArrowImport>;
152}
153
154pub type ArrowExportVTableRef = Arc<dyn ArrowExportVTable>;
155pub type ArrowImportVTableRef = Arc<dyn ArrowImportVTable>;
156
157#[derive(Debug)]
167pub struct ArrowSession {
168 exporters: ArcSwapMap<Id, Arc<[ArrowExportVTableRef]>>,
169 exporters_by_vortex: ArcSwapMap<ExtId, Arc<[ArrowExportVTableRef]>>,
170 importers: ArcSwapMap<Id, Arc<[ArrowImportVTableRef]>>,
171}
172
173impl Default for ArrowSession {
174 fn default() -> Self {
175 let session = Self {
176 exporters: ArcSwapMap::default(),
177 exporters_by_vortex: ArcSwapMap::default(),
178 importers: ArcSwapMap::default(),
179 };
180
181 session.register_exporter(Arc::new(Uuid));
182 session.register_importer(Arc::new(Uuid));
183
184 session
185 }
186}
187
188impl ArrowSession {
189 pub fn register_exporter(&self, exporter: ArrowExportVTableRef) {
192 self.exporters.push(
193 exporter.arrow_ext_id(),
194 ArrowExportVTableRef::clone(&exporter),
195 );
196 self.exporters_by_vortex
197 .push(exporter.vortex_id(), exporter);
198 }
199
200 pub fn register_importer(&self, importer: ArrowImportVTableRef) {
202 self.importers.push(importer.arrow_ext_id(), importer);
203 }
204
205 fn exporters(&self, id: &Id) -> Arc<[ArrowExportVTableRef]> {
206 self.exporters.get(id).unwrap_or_else(|| Arc::from([]))
207 }
208
209 fn exporters_by_vortex(&self, id: &Id) -> Arc<[ArrowExportVTableRef]> {
210 self.exporters_by_vortex
211 .get(id)
212 .unwrap_or_else(|| Arc::from([]))
213 }
214
215 fn importers(&self, id: &Id) -> Arc<[ArrowImportVTableRef]> {
216 self.importers.get(id).unwrap_or_else(|| Arc::from([]))
217 }
218
219 pub fn to_arrow_field(&self, name: &str, dtype: &DType) -> VortexResult<Field> {
224 match dtype {
226 DType::List(elem_dtype, nullability) => {
227 let elem_field = self.to_arrow_field(Field::LIST_FIELD_DEFAULT_NAME, elem_dtype)?;
228 Ok(Field::new_list(name, elem_field, nullability.is_nullable()))
229 }
230 DType::FixedSizeList(elem_dtype, elem_size, nullability) => {
231 let elem_field = self.to_arrow_field(Field::LIST_FIELD_DEFAULT_NAME, elem_dtype)?;
232 Ok(Field::new_fixed_size_list(
233 name,
234 elem_field,
235 (*elem_size).try_into()?,
236 nullability.is_nullable(),
237 ))
238 }
239 DType::Struct(fields, nullability) => {
240 let arrow_fields = Fields::from_iter(
241 fields
242 .fields()
243 .zip(fields.names().iter())
244 .map(|(field, name)| self.to_arrow_field(name.as_ref(), &field))
245 .collect::<VortexResult<Vec<_>>>()?,
246 );
247 Ok(Field::new_struct(
248 name,
249 arrow_fields,
250 nullability.is_nullable(),
251 ))
252 }
253 DType::Extension(ext) if !ext.is::<AnyTemporal>() => {
254 for plugin in self.exporters_by_vortex(&ext.id()).iter() {
255 if let Some(field) =
256 plugin.to_arrow_field(name, &DType::Extension(ext.clone()), self)?
257 {
258 return Ok(field);
259 }
260 }
261 vortex_bail!("extension type cannot be converted to Arrow without a plugin: {ext}");
262 }
263 DType::Variant(_) => {
264 Ok(Field::new(
268 name,
269 DataType::Struct(
270 vec![
271 Field::new("metadata", DataType::BinaryView, dtype.is_nullable()),
272 Field::new("value", DataType::BinaryView, dtype.is_nullable()),
273 ]
274 .into(),
275 ),
276 dtype.is_nullable(),
277 )
278 .with_metadata(
279 [(
280 EXTENSION_TYPE_NAME_KEY.to_string(),
281 "arrow.parquet.variant".to_string(),
282 )]
283 .into(),
284 ))
285 }
286 _ => Ok(Field::new(
287 name,
288 to_data_type_naive(dtype)?,
289 dtype.is_nullable(),
290 )),
291 }
292 }
293
294 pub fn to_arrow_schema(&self, dtype: &DType) -> VortexResult<Schema> {
298 let DType::Struct(struct_dtype, _) = dtype else {
299 vortex_error::vortex_bail!(
300 "to_arrow_schema requires a top-level struct dtype, got {dtype}"
301 );
302 };
303 let mut fields = Vec::with_capacity(struct_dtype.names().len());
304 for (name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
305 fields.push(self.to_arrow_field(name.as_ref(), &field_dtype)?);
306 }
307 Ok(Schema::new(fields))
308 }
309
310 pub fn from_arrow_field(&self, field: &Field) -> VortexResult<DType> {
319 if let Some(name) = field.metadata().get(EXTENSION_TYPE_NAME_KEY) {
320 for plugin in self.importers(&Id::new(name)).iter() {
321 if let Some(dtype) = plugin.from_arrow_field(field)? {
322 return Ok(dtype);
323 }
324 }
325 }
326 let nullability: Nullability = field.is_nullable().into();
327 Ok(match field.data_type() {
328 DataType::List(elem)
329 | DataType::LargeList(elem)
330 | DataType::ListView(elem)
331 | DataType::LargeListView(elem) => {
332 DType::List(Arc::new(self.from_arrow_field(elem.as_ref())?), nullability)
333 }
334 DataType::FixedSizeList(elem, size) => DType::FixedSizeList(
335 Arc::new(self.from_arrow_field(elem.as_ref())?),
336 *size as u32,
337 nullability,
338 ),
339 DataType::Struct(fields) => {
340 let entries = fields
341 .iter()
342 .map(|f| {
343 self.from_arrow_field(f)
344 .map(|dt| (FieldName::from(f.name().as_str()), dt))
345 })
346 .collect::<VortexResult<Vec<_>>>()?;
347 DType::Struct(StructFields::from_iter(entries), nullability)
348 }
349 _ => DType::from_arrow(field),
350 })
351 }
352
353 pub fn from_arrow_schema(&self, schema: &Schema) -> VortexResult<DType> {
357 let entries = schema
358 .fields()
359 .iter()
360 .map(|f| {
361 self.from_arrow_field(f)
362 .map(|dt| (FieldName::from(f.name().as_str()), dt))
363 })
364 .collect::<VortexResult<Vec<_>>>()?;
365 Ok(DType::Struct(
366 StructFields::from_iter(entries),
367 Nullability::NonNullable,
368 ))
369 }
370
371 pub fn from_arrow_record_batch(
379 &self,
380 batch: RecordBatch,
381 schema: &Schema,
382 ) -> VortexResult<ArrayRef> {
383 vortex_ensure!(
384 batch.num_columns() == schema.fields().len(),
385 "RecordBatch has {} columns but schema has {} fields",
386 batch.num_columns(),
387 schema.fields().len()
388 );
389 let length = batch.num_rows();
390 let names = FieldNames::from_iter(
391 schema
392 .fields()
393 .iter()
394 .map(|f| FieldName::from(f.name().as_str())),
395 );
396 let mut columns = Vec::with_capacity(schema.fields().len());
397 for (col, field) in batch.columns().iter().zip(schema.fields().iter()) {
398 columns.push(self.from_arrow_array(ArrowArrayRef::clone(col), field)?);
399 }
400 Ok(StructArray::try_new(names, columns, length, Validity::NonNullable)?.into_array())
401 }
402
403 pub fn execute_arrow(
411 &self,
412 array: ArrayRef,
413 target: Option<&Field>,
414 ctx: &mut ExecutionCtx,
415 ) -> VortexResult<ArrowArrayRef> {
416 let arrow_field;
420 let target_field = match target {
421 Some(field) => field,
422 None => {
423 let session = ctx.session().clone();
424 arrow_field = session.arrow().to_arrow_field("", array.dtype())?;
425 &arrow_field
426 }
427 };
428
429 if let Some(arrow_ext_name) = target_field.metadata().get(EXTENSION_TYPE_NAME_KEY) {
430 let len = array.len();
433 let mut current = array;
434
435 for plugin in self.exporters(&Id::new(arrow_ext_name)).iter() {
436 trace!(
437 plugin = ?plugin,
438 extension_name = arrow_ext_name,
439 "probing plugin for converting Arrow array"
440 );
441
442 match plugin.execute_arrow(current, target_field, ctx)? {
443 ArrowExport::Exported(arrow) => {
444 vortex_ensure!(
445 arrow.len() == len,
446 "Arrow array length does not match Vortex array length after conversion to {:?}",
447 arrow
448 );
449 return Ok(arrow);
450 }
451 ArrowExport::Unsupported(array) => current = array,
452 }
453 }
454
455 debug!(
456 extension_id = arrow_ext_name,
457 data_type = ?target_field.data_type(),
458 "unsupported Arrow extension type encountered, falling back to naive execution"
459 );
460
461 return execute_arrow_naive(current, Some(target_field.data_type()), ctx);
462 }
463
464 execute_arrow_naive(array, target.map(|field| field.data_type()), ctx)
465 }
466
467 pub fn from_arrow_array(&self, array: ArrowArrayRef, field: &Field) -> VortexResult<ArrayRef> {
477 if let Some(extension_name) = field.metadata().get(EXTENSION_TYPE_NAME_KEY) {
478 let importers = self.importers(&Id::new(extension_name));
479 if !importers.is_empty() {
480 let dtype = self.from_arrow_field(field)?;
481 let mut current = array;
482 for plugin in importers.iter() {
483 match plugin.from_arrow_array(current, field, &dtype)? {
484 ArrowImport::Imported(arr) => return Ok(arr),
485 ArrowImport::Unsupported(arr) => current = arr,
486 }
487 }
488 return ArrayRef::from_arrow(current.as_ref(), field.is_nullable());
489 }
490 }
491 self.from_arrow_array_canonical(array, field)
492 }
493
494 #[allow(clippy::wrong_self_convention)]
497 fn from_arrow_array_canonical(
498 &self,
499 array: ArrowArrayRef,
500 field: &Field,
501 ) -> VortexResult<ArrayRef> {
502 use arrow_array::cast::AsArray;
503
504 match field.data_type() {
505 DataType::Struct(fields) => {
506 let arrow_struct = array.as_struct();
507 let names = FieldNames::from_iter(
508 fields.iter().map(|f| FieldName::from(f.name().as_str())),
509 );
510 let columns = arrow_struct
511 .columns()
512 .iter()
513 .zip(fields.iter())
514 .map(|(col, child_field)| {
515 let inner = if col.null_count() > 0 && !child_field.is_nullable() {
518 make_array(remove_nulls(col.to_data())?)
519 } else {
520 ArrowArrayRef::clone(col)
521 };
522 self.from_arrow_array(inner, child_field.as_ref())
523 })
524 .collect::<VortexResult<Vec<_>>>()?;
525 let validity = nulls(arrow_struct.nulls(), field.is_nullable())?;
526 Ok(
527 StructArray::try_new(names, columns, arrow_struct.len(), validity)?
528 .into_array(),
529 )
530 }
531 DataType::List(elem_field) => {
532 let list = array.as_list::<i32>();
533 let elements = self
534 .from_arrow_array(ArrowArrayRef::clone(list.values()), elem_field.as_ref())?;
535 let offsets = list.offsets().clone().into_array();
536 let validity = nulls(list.nulls(), field.is_nullable())?;
537 Ok(crate::arrays::ListArray::try_new(elements, offsets, validity)?.into_array())
538 }
539 DataType::LargeList(elem_field) => {
540 let list = array.as_list::<i64>();
541 let elements = self
542 .from_arrow_array(ArrowArrayRef::clone(list.values()), elem_field.as_ref())?;
543 let offsets = list.offsets().clone().into_array();
544 let validity = nulls(list.nulls(), field.is_nullable())?;
545 Ok(crate::arrays::ListArray::try_new(elements, offsets, validity)?.into_array())
546 }
547 DataType::FixedSizeList(elem_field, list_size) => {
548 let fsl = array.as_fixed_size_list();
549 let elements =
550 self.from_arrow_array(ArrowArrayRef::clone(fsl.values()), elem_field.as_ref())?;
551 let validity = nulls(fsl.nulls(), field.is_nullable())?;
552 Ok(crate::arrays::FixedSizeListArray::try_new(
553 elements,
554 *list_size as u32,
555 validity,
556 fsl.len(),
557 )?
558 .into_array())
559 }
560 DataType::ListView(elem_field) => {
561 let list = array.as_list_view::<i32>();
562 let elements = self
563 .from_arrow_array(ArrowArrayRef::clone(list.values()), elem_field.as_ref())?;
564 let offsets = list.offsets().clone().into_array();
565 let sizes = list.sizes().clone().into_array();
566 let validity = nulls(list.nulls(), field.is_nullable())?;
567 Ok(
568 crate::arrays::ListViewArray::try_new(elements, offsets, sizes, validity)?
569 .into_array(),
570 )
571 }
572 DataType::LargeListView(elem_field) => {
573 let list = array.as_list_view::<i64>();
574 let elements = self
575 .from_arrow_array(ArrowArrayRef::clone(list.values()), elem_field.as_ref())?;
576 let offsets = list.offsets().clone().into_array();
577 let sizes = list.sizes().clone().into_array();
578 let validity = nulls(list.nulls(), field.is_nullable())?;
579 Ok(
580 crate::arrays::ListViewArray::try_new(elements, offsets, sizes, validity)?
581 .into_array(),
582 )
583 }
584 _ => ArrayRef::from_arrow(array.as_ref(), field.is_nullable()),
585 }
586 }
587}
588
589pub(crate) fn has_valid_extension_type<E: ExtensionType>(field: &Field) -> bool {
593 if field.extension_type_name() != Some(E::NAME) {
594 return false;
595 }
596
597 E::try_new_from_field_metadata(field.data_type(), field.metadata()).is_ok()
598}
599
600impl SessionVar for ArrowSession {
601 fn as_any(&self) -> &dyn Any {
602 self
603 }
604
605 fn as_any_mut(&mut self) -> &mut dyn Any {
606 self
607 }
608}
609
610pub trait ArrowSessionExt: SessionExt {
612 fn arrow(&self) -> Ref<'_, ArrowSession>;
614}
615
616impl<S: SessionExt> ArrowSessionExt for S {
617 fn arrow(&self) -> Ref<'_, ArrowSession> {
618 self.get::<ArrowSession>()
619 }
620}
621
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::FixedSizeBinaryArray;
    use arrow_array::cast::AsArray;
    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::extension::Uuid as ArrowUuid;
    use vortex_error::VortexResult;

    use super::*;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::dtype::extension::ExtDType;
    use crate::dtype::extension::ExtVTable;
    use crate::extension::uuid::Uuid;
    use crate::extension::uuid::UuidMetadata;

    /// Vortex UUID extension dtype over non-nullable `u8` x 16 storage; `nullable` controls the
    /// nullability of the outer fixed-size list.
    fn uuid_dtype(nullable: bool) -> DType {
        let byte = Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable));
        let storage = DType::FixedSizeList(byte, 16, nullable.into());
        let ext = ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), storage)
            .expect("uuid ext dtype");
        DType::Extension(ext.erased())
    }

    /// A bare UUID dtype becomes a field tagged with Arrow's UUID extension.
    #[test]
    fn to_arrow_field_top_level_uuid_carries_extension_metadata() -> VortexResult<()> {
        let arrow_field = ArrowSession::default().to_arrow_field("id", &uuid_dtype(false))?;
        assert!(has_valid_extension_type::<ArrowUuid>(&arrow_field));
        Ok(())
    }

    /// A UUID inside a struct keeps its storage type and extension tag.
    #[test]
    fn to_arrow_field_struct_with_nested_uuid_preserves_metadata() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let row_dtype = DType::Struct(
            StructFields::from_iter([(FieldName::from("id"), uuid_dtype(false))]),
            Nullability::NonNullable,
        );

        let row = registry.to_arrow_field("row", &row_dtype)?;
        let DataType::Struct(children) = row.data_type() else {
            panic!("expected Struct, got {:?}", row.data_type());
        };

        assert_eq!(children.len(), 1);
        let id = &children[0];
        assert_eq!(id.data_type(), &DataType::FixedSizeBinary(16));
        assert!(has_valid_extension_type::<ArrowUuid>(id));
        Ok(())
    }

    /// The element field of a list of UUIDs keeps the extension tag.
    #[test]
    fn to_arrow_field_list_of_uuid_preserves_metadata() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let list_dtype = DType::List(Arc::new(uuid_dtype(true)), Nullability::NonNullable);

        let ids = registry.to_arrow_field("ids", &list_dtype)?;
        let DataType::List(item) = ids.data_type() else {
            panic!("expected List, got {:?}", ids.data_type());
        };

        assert!(has_valid_extension_type::<ArrowUuid>(item));
        Ok(())
    }

    /// The element field of a fixed-size list of UUIDs keeps the tag, and the size survives.
    #[test]
    fn to_arrow_field_fixed_size_list_of_uuid_preserves_metadata() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let triple_dtype =
            DType::FixedSizeList(Arc::new(uuid_dtype(false)), 3, Nullability::NonNullable);

        let triple = registry.to_arrow_field("triple", &triple_dtype)?;
        let DataType::FixedSizeList(item, width) = triple.data_type() else {
            panic!("expected FixedSizeList, got {:?}", triple.data_type());
        };

        assert_eq!(*width, 3);
        assert!(has_valid_extension_type::<ArrowUuid>(item));
        Ok(())
    }

    /// Schema conversion reaches a UUID two struct levels deep.
    #[test]
    fn to_arrow_schema_struct_of_struct_uuid() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let payload_dtype = DType::Struct(
            StructFields::from_iter([(FieldName::from("id"), uuid_dtype(true))]),
            Nullability::NonNullable,
        );
        let top_dtype = DType::Struct(
            StructFields::from_iter([(FieldName::from("payload"), payload_dtype)]),
            Nullability::NonNullable,
        );

        let schema = registry.to_arrow_schema(&top_dtype)?;
        let payload = schema.field(0);
        let DataType::Struct(payload_children) = payload.data_type() else {
            panic!("expected Struct, got {:?}", payload.data_type());
        };

        assert!(has_valid_extension_type::<ArrowUuid>(&payload_children[0]));
        Ok(())
    }

    /// Importing a list whose element is tagged as UUID yields a Vortex UUID element dtype.
    #[test]
    fn from_arrow_field_recurses_into_nested_uuid() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let mut item = Field::new("item", DataType::FixedSizeBinary(16), false);
        item.try_with_extension_type(ArrowUuid)?;
        let ids = Field::new("ids", DataType::List(Arc::new(item)), false);

        let dtype = registry.from_arrow_field(&ids)?;
        let DType::List(element, _) = dtype else {
            panic!("expected List dtype, got {dtype}");
        };

        assert!(
            matches!(element.as_ref(), DType::Extension(ext) if ext.id() == Uuid.id()),
            "expected Uuid extension element, got {element}",
        );
        Ok(())
    }

    /// Vortex -> Arrow schema -> Vortex reproduces a dtype with top-level and nested UUIDs.
    #[test]
    fn schema_roundtrip_preserves_nested_uuid() -> VortexResult<()> {
        let registry = ArrowSession::default();
        let original = DType::Struct(
            StructFields::from_iter([
                (FieldName::from("id"), uuid_dtype(false)),
                (
                    FieldName::from("ids"),
                    DType::List(Arc::new(uuid_dtype(true)), Nullability::NonNullable),
                ),
            ]),
            Nullability::NonNullable,
        );

        let schema = registry.to_arrow_schema(&original)?;
        let back = registry.from_arrow_schema(&schema)?;

        assert_eq!(back, original);
        Ok(())
    }

    /// Import a tagged Arrow UUID column, then export it with no target: the bytes survive
    /// and the Arrow type is `FixedSizeBinary(16)`.
    #[test]
    fn execute_arrow_target_none_preserves_top_level_uuid_metadata() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let registry = LEGACY_SESSION.arrow();

        let mut uuid_field = Field::new("id", DataType::FixedSizeBinary(16), false);
        uuid_field.try_with_extension_type(ArrowUuid)?;
        let first = *b"0123456789abcdef";
        let second = *b"fedcba9876543210";
        let input: ArrowArrayRef = Arc::new(FixedSizeBinaryArray::try_from_iter(
            [first, second].into_iter(),
        )?);

        let imported = registry.from_arrow_array(input, &uuid_field)?;
        assert!(imported.dtype().as_extension().is::<Uuid>());

        let exported = registry.execute_arrow(imported, None, &mut ctx)?;
        assert_eq!(exported.data_type(), &DataType::FixedSizeBinary(16));

        let bytes = exported.as_fixed_size_binary();
        assert_eq!(bytes.len(), 2);
        assert_eq!(bytes.value(0), &first);
        assert_eq!(bytes.value(1), &second);
        Ok(())
    }
}