vortex_array/arrays/extension/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6
7use std::sync::Arc;
8
9use vortex_dtype::ExtDType;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12use vortex_scalar::Scalar;
13
14use crate::arrays::ExtensionVTable;
15use crate::arrays::extension::ExtensionArray;
16use crate::compute::{
17    FilterKernel, FilterKernelAdapter, IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts,
18    IsSortedKernel, IsSortedKernelAdapter, MaskKernel, MaskKernelAdapter, MinMaxKernel,
19    MinMaxKernelAdapter, MinMaxResult, SumKernel, SumKernelAdapter, TakeKernel, TakeKernelAdapter,
20    filter, is_constant_opts, is_sorted, is_strict_sorted, mask, min_max, sum, take,
21};
22use crate::{Array, ArrayRef, IntoArray, register_kernel};
23
24impl FilterKernel for ExtensionVTable {
25    fn filter(&self, array: &ExtensionArray, mask: &Mask) -> VortexResult<ArrayRef> {
26        Ok(
27            ExtensionArray::new(array.ext_dtype().clone(), filter(array.storage(), mask)?)
28                .into_array(),
29        )
30    }
31}
32
33register_kernel!(FilterKernelAdapter(ExtensionVTable).lift());
34
35impl MaskKernel for ExtensionVTable {
36    fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
37        let masked_storage = mask(array.storage(), mask_array)?;
38        if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
39            Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array())
40        } else {
41            // The storage dtype changed (i.e., became nullable due to masking)
42            let ext_dtype = Arc::new(ExtDType::new(
43                array.ext_dtype().id().clone(),
44                Arc::new(masked_storage.dtype().clone()),
45                array.ext_dtype().metadata().cloned(),
46            ));
47            Ok(ExtensionArray::new(ext_dtype, masked_storage).into_array())
48        }
49    }
50}
51
52register_kernel!(MaskKernelAdapter(ExtensionVTable).lift());
53
54impl SumKernel for ExtensionVTable {
55    fn sum(&self, array: &ExtensionArray) -> VortexResult<Scalar> {
56        sum(array.storage())
57    }
58}
59
60register_kernel!(SumKernelAdapter(ExtensionVTable).lift());
61
62impl TakeKernel for ExtensionVTable {
63    fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
64        let taken_storage = take(array.storage(), indices)?;
65        if taken_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
66            Ok(ExtensionArray::new(array.ext_dtype().clone(), taken_storage).into_array())
67        } else {
68            // The storage dtype changed (i.e., became nullable due to nullable indices)
69            let ext_dtype = Arc::new(ExtDType::new(
70                array.ext_dtype().id().clone(),
71                Arc::new(taken_storage.dtype().clone()),
72                array.ext_dtype().metadata().cloned(),
73            ));
74            Ok(ExtensionArray::new(ext_dtype, taken_storage).into_array())
75        }
76    }
77}
78
79register_kernel!(TakeKernelAdapter(ExtensionVTable).lift());
80
81impl MinMaxKernel for ExtensionVTable {
82    fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
83        Ok(
84            min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
85                min: Scalar::extension(array.ext_dtype().clone(), min),
86                max: Scalar::extension(array.ext_dtype().clone(), max),
87            }),
88        )
89    }
90}
91
92register_kernel!(MinMaxKernelAdapter(ExtensionVTable).lift());
93
94impl IsConstantKernel for ExtensionVTable {
95    fn is_constant(
96        &self,
97        array: &ExtensionArray,
98        opts: &IsConstantOpts,
99    ) -> VortexResult<Option<bool>> {
100        is_constant_opts(array.storage(), opts)
101    }
102}
103
104register_kernel!(IsConstantKernelAdapter(ExtensionVTable).lift());
105
106impl IsSortedKernel for ExtensionVTable {
107    fn is_sorted(&self, array: &ExtensionArray) -> VortexResult<Option<bool>> {
108        is_sorted(array.storage())
109    }
110
111    fn is_strict_sorted(&self, array: &ExtensionArray) -> VortexResult<Option<bool>> {
112        is_strict_sorted(array.storage())
113    }
114}
115
116register_kernel!(IsSortedKernelAdapter(ExtensionVTable).lift());
117
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::take::test_take_conformance;

    /// Builds the metadata-free "uuid" extension dtype used throughout these tests, on top of
    /// the given storage dtype.
    fn uuid_ext_dtype(storage_dtype: DType) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(storage_dtype),
            None,
        ))
    }

    #[test]
    fn test_filter_extension_array() {
        // A simple extension type (e.g., UUID represented as u64) over non-nullable storage.
        let non_nullable = ExtensionArray::new(
            uuid_ext_dtype(DType::Primitive(PType::U64, Nullability::NonNullable)),
            buffer![1u64, 2, 3, 4, 5].into_array(),
        );
        test_filter_conformance(non_nullable.as_ref());

        // The same extension type over nullable storage that contains nulls.
        let nullable = ExtensionArray::new(
            uuid_ext_dtype(DType::Primitive(PType::U64, Nullability::Nullable)),
            PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None])
                .into_array(),
        );
        test_filter_conformance(nullable.as_ref());
    }

    #[rstest]
    // Simple extension type (non-nullable u64)
    #[case({
        let storage = buffer![1u64, 2, 3, 4, 5].into_array();
        ExtensionArray::new(uuid_ext_dtype(storage.dtype().clone()), storage)
    })]
    // Nullable extension type
    #[case({
        let storage = PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None])
            .into_array();
        ExtensionArray::new(uuid_ext_dtype(storage.dtype().clone()), storage)
    })]
    // Single element
    #[case({
        let storage = buffer![42u64].into_array();
        ExtensionArray::new(uuid_ext_dtype(storage.dtype().clone()), storage)
    })]
    // Larger array for edge cases
    #[case({
        let storage = buffer![0u64..100].into_array();
        ExtensionArray::new(uuid_ext_dtype(storage.dtype().clone()), storage)
    })]
    fn test_take_extension_array_conformance(#[case] array: ExtensionArray) {
        test_take_conformance(array.as_ref());
    }
}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{ExtDType, ExtID};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::consistency::test_array_consistency;

    /// Wraps `storage` in a metadata-free "test_ext" extension array whose storage dtype is
    /// taken from `storage` itself.
    fn test_ext_array(storage: crate::ArrayRef) -> ExtensionArray {
        let ext_dtype = ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    #[rstest]
    // Note: the original test_all_consistency cases for extension arrays caused errors because
    // the "uuid" extension type was unsupported, so simpler test cases are used here.
    #[case::extension_simple(test_ext_array(buffer![1u64, 2, 3, 4, 5].into_array()))]
    #[case::extension_nullable(test_ext_array(
        PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None]).into_array()
    ))]
    // Additional test cases
    #[case::extension_single(test_ext_array(buffer![42i32].into_array()))]
    #[case::extension_large(test_ext_array(buffer![0..100i64].into_array()))]
    fn test_extension_consistency(#[case] array: ExtensionArray) {
        test_array_consistency(array.as_ref());
    }
}