vortex_array/arrays/extension/
mod.rs

1use std::sync::Arc;
2
3use vortex_dtype::{DType, ExtDType, ExtID};
4use vortex_error::VortexResult;
5use vortex_mask::Mask;
6
7use crate::array::{ArrayCanonicalImpl, ArrayValidityImpl};
8use crate::stats::{ArrayStats, Stat, StatsSet, StatsSetRef};
9use crate::variants::ExtensionArrayTrait;
10use crate::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
11use crate::{
12    Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayVariantsImpl, Canonical, EmptyMetadata,
13    Encoding, EncodingId,
14};
15mod compute;
16mod serde;
17
18#[derive(Clone, Debug)]
19pub struct ExtensionArray {
20    dtype: DType,
21    storage: ArrayRef,
22    stats_set: ArrayStats,
23}
24
25pub struct ExtensionEncoding;
26impl Encoding for ExtensionEncoding {
27    type Array = ExtensionArray;
28    type Metadata = EmptyMetadata;
29}
30
31impl EncodingVTable for ExtensionEncoding {
32    fn id(&self) -> EncodingId {
33        EncodingId::new_ref("vortex.ext")
34    }
35}
36
37impl ExtensionArray {
38    pub fn new(ext_dtype: Arc<ExtDType>, storage: ArrayRef) -> Self {
39        assert_eq!(
40            ext_dtype.storage_dtype(),
41            storage.dtype(),
42            "ExtensionArray: storage_dtype must match storage array DType",
43        );
44        Self {
45            dtype: DType::Extension(ext_dtype),
46            storage,
47            stats_set: ArrayStats::default(),
48        }
49    }
50
51    pub fn storage(&self) -> &ArrayRef {
52        &self.storage
53    }
54
55    #[allow(dead_code)]
56    #[inline]
57    pub fn id(&self) -> &ExtID {
58        self.ext_dtype().id()
59    }
60}
61
62impl ArrayImpl for ExtensionArray {
63    type Encoding = ExtensionEncoding;
64
65    fn _len(&self) -> usize {
66        self.storage.len()
67    }
68
69    fn _dtype(&self) -> &DType {
70        &self.dtype
71    }
72
73    fn _vtable(&self) -> VTableRef {
74        VTableRef::new_ref(&ExtensionEncoding)
75    }
76}
77
78impl ArrayStatisticsImpl for ExtensionArray {
79    fn _stats_ref(&self) -> StatsSetRef<'_> {
80        self.stats_set.to_ref(self)
81    }
82}
83
84impl ArrayCanonicalImpl for ExtensionArray {
85    fn _to_canonical(&self) -> VortexResult<Canonical> {
86        Ok(Canonical::Extension(self.clone()))
87    }
88}
89
90impl ArrayValidityImpl for ExtensionArray {
91    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
92        self.storage.is_valid(index)
93    }
94
95    fn _all_valid(&self) -> VortexResult<bool> {
96        self.storage.all_valid()
97    }
98
99    fn _all_invalid(&self) -> VortexResult<bool> {
100        self.storage.all_invalid()
101    }
102
103    fn _validity_mask(&self) -> VortexResult<Mask> {
104        self.storage.validity_mask()
105    }
106}
107
108impl ArrayVariantsImpl for ExtensionArray {
109    fn _as_extension_typed(&self) -> Option<&dyn ExtensionArrayTrait> {
110        Some(self)
111    }
112}
113
114impl ExtensionArrayTrait for ExtensionArray {
115    fn storage_data(&self) -> ArrayRef {
116        self.storage().clone()
117    }
118}
119
120impl StatisticsVTable<&ExtensionArray> for ExtensionEncoding {
121    fn compute_statistics(&self, array: &'_ ExtensionArray, stat: Stat) -> VortexResult<StatsSet> {
122        // No need to cast the storage statistics since we return untyped ScalarValue.
123        array.storage().statistics().compute_all(&[stat])
124    }
125}
126
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PType;

    use super::*;
    use crate::IntoArray;
    use crate::stats::{Precision, StatsProviderExt};

    /// Min, max and null count requested on an extension array over `i64`
    /// storage come back as exact values computed from that storage.
    #[test]
    fn compute_statistics() {
        let ext_dtype = Arc::new(ExtDType::new(
            ExtID::new("timestamp".into()),
            DType::from(PType::I64).into(),
            None,
        ));
        let storage = buffer![1i64, 2, 3, 4, 5].into_array();
        let array = ExtensionArray::new(ext_dtype, storage);

        let requested = [Stat::Min, Stat::Max, Stat::NullCount];
        let stats = array.statistics().compute_all(&requested).unwrap();

        // Computing may yield additional stats, so only a lower bound is checked.
        let num_stats = stats.clone().into_iter().count();
        assert!(
            num_stats >= 3,
            "Expected at least 3 stats, got {}",
            num_stats
        );

        let min = stats.get_as::<i64>(Stat::Min);
        assert_eq!(min, Some(Precision::exact(1i64)));

        let max = stats.get_as::<i64>(Stat::Max);
        assert_eq!(max, Some(Precision::exact(5_i64)));

        let null_count = stats.get_as::<usize>(Stat::NullCount);
        assert_eq!(null_count, Some(Precision::exact(0usize)));
    }
}
166}