// vortex_array/arrays/extension/mod.rs

1use std::sync::Arc;
2
3use vortex_dtype::{DType, ExtDType, ExtID};
4use vortex_error::VortexResult;
5use vortex_scalar::Scalar;
6
7use crate::stats::{ArrayStats, StatsSetRef};
8use crate::vtable::{
9    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityChild,
10    ValidityVTableFromChild, VisitorVTable,
11};
12use crate::{
13    Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef,
14    IntoArray, vtable,
15};
16
17mod compute;
18mod serde;
19
// Crate-level `vtable!` macro invocation for the `Extension` encoding. The impls in this
// file target `ExtensionVTable`, which is not declared anywhere else in this file, so it
// is presumably generated here — check the `vtable!` definition for the exact items emitted.
vtable!(Extension);
21
22impl VTable for ExtensionVTable {
23    type Array = ExtensionArray;
24    type Encoding = ExtensionEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromChild;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = NotSupported;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.ext")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(ExtensionEncoding.as_ref())
41    }
42}
43
/// Zero-sized marker type for the extension encoding; its vtable id is `"vortex.ext"`.
#[derive(Debug, Clone)]
pub struct ExtensionEncoding;
46
47#[derive(Clone, Debug)]
48pub struct ExtensionArray {
49    dtype: DType,
50    storage: ArrayRef,
51    stats_set: ArrayStats,
52}
53
54impl ExtensionArray {
55    pub fn new(ext_dtype: Arc<ExtDType>, storage: ArrayRef) -> Self {
56        assert_eq!(
57            ext_dtype.storage_dtype(),
58            storage.dtype(),
59            "ExtensionArray: storage_dtype must match storage array DType",
60        );
61        Self {
62            dtype: DType::Extension(ext_dtype),
63            storage,
64            stats_set: ArrayStats::default(),
65        }
66    }
67
68    pub fn ext_dtype(&self) -> &Arc<ExtDType> {
69        let DType::Extension(ext) = &self.dtype else {
70            unreachable!("ExtensionArray: dtype must be an ExtDType")
71        };
72        ext
73    }
74
75    pub fn storage(&self) -> &ArrayRef {
76        &self.storage
77    }
78
79    #[allow(dead_code)]
80    #[inline]
81    pub fn id(&self) -> &ExtID {
82        self.ext_dtype().id()
83    }
84}
85
86impl ArrayVTable<ExtensionVTable> for ExtensionVTable {
87    fn len(array: &ExtensionArray) -> usize {
88        array.storage.len()
89    }
90
91    fn dtype(array: &ExtensionArray) -> &DType {
92        &array.dtype
93    }
94
95    fn stats(array: &ExtensionArray) -> StatsSetRef<'_> {
96        array.stats_set.to_ref(array.as_ref())
97    }
98}
99
100impl ValidityChild<ExtensionVTable> for ExtensionVTable {
101    fn validity_child(array: &ExtensionArray) -> &dyn Array {
102        array.storage.as_ref()
103    }
104}
105
106impl CanonicalVTable<ExtensionVTable> for ExtensionVTable {
107    fn canonicalize(array: &ExtensionArray) -> VortexResult<Canonical> {
108        Ok(Canonical::Extension(array.clone()))
109    }
110}
111
112impl OperationsVTable<ExtensionVTable> for ExtensionVTable {
113    fn slice(array: &ExtensionArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
114        Ok(ExtensionArray::new(
115            array.ext_dtype().clone(),
116            array.storage().slice(start, stop)?,
117        )
118        .into_array())
119    }
120
121    fn scalar_at(array: &ExtensionArray, index: usize) -> VortexResult<Scalar> {
122        Ok(Scalar::extension(
123            array.ext_dtype().clone(),
124            array.storage().scalar_at(index)?,
125        ))
126    }
127}
128
129impl VisitorVTable<ExtensionVTable> for ExtensionVTable {
130    fn visit_buffers(_array: &ExtensionArray, _visitor: &mut dyn ArrayBufferVisitor) {}
131
132    fn visit_children(array: &ExtensionArray, visitor: &mut dyn ArrayChildVisitor) {
133        visitor.visit_child("storage", array.storage.as_ref());
134    }
135}