vortex_array/arrays/extension/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6
7use std::sync::Arc;
8
9use vortex_dtype::ExtDType;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12use vortex_scalar::Scalar;
13
14use crate::arrays::ExtensionVTable;
15use crate::arrays::extension::ExtensionArray;
16use crate::compute::{
17    FilterKernel, FilterKernelAdapter, IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts,
18    IsSortedKernel, IsSortedKernelAdapter, MaskKernel, MaskKernelAdapter, MinMaxKernel,
19    MinMaxKernelAdapter, MinMaxResult, SumKernel, SumKernelAdapter, TakeKernel, TakeKernelAdapter,
20    filter, is_constant_opts, is_sorted, is_strict_sorted, mask, min_max, sum, take,
21};
22use crate::{Array, ArrayRef, IntoArray, register_kernel};
23
24impl FilterKernel for ExtensionVTable {
25    fn filter(&self, array: &ExtensionArray, mask: &Mask) -> VortexResult<ArrayRef> {
26        Ok(
27            ExtensionArray::new(array.ext_dtype().clone(), filter(array.storage(), mask)?)
28                .into_array(),
29        )
30    }
31}
32
33register_kernel!(FilterKernelAdapter(ExtensionVTable).lift());
34
35impl MaskKernel for ExtensionVTable {
36    fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
37        let masked_storage = mask(array.storage(), mask_array)?;
38        if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
39            Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array())
40        } else {
41            // The storage dtype changed (i.e., became nullable due to masking)
42            let ext_dtype = Arc::new(ExtDType::new(
43                array.ext_dtype().id().clone(),
44                Arc::new(masked_storage.dtype().clone()),
45                array.ext_dtype().metadata().cloned(),
46            ));
47            Ok(ExtensionArray::new(ext_dtype, masked_storage).into_array())
48        }
49    }
50}
51
52register_kernel!(MaskKernelAdapter(ExtensionVTable).lift());
53
54impl SumKernel for ExtensionVTable {
55    fn sum(&self, array: &ExtensionArray) -> VortexResult<Scalar> {
56        sum(array.storage())
57    }
58}
59
60register_kernel!(SumKernelAdapter(ExtensionVTable).lift());
61
62impl TakeKernel for ExtensionVTable {
63    fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
64        let taken_storage = take(array.storage(), indices)?;
65        if taken_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
66            Ok(ExtensionArray::new(array.ext_dtype().clone(), taken_storage).into_array())
67        } else {
68            // The storage dtype changed (i.e., became nullable due to nullable indices)
69            let ext_dtype = Arc::new(ExtDType::new(
70                array.ext_dtype().id().clone(),
71                Arc::new(taken_storage.dtype().clone()),
72                array.ext_dtype().metadata().cloned(),
73            ));
74            Ok(ExtensionArray::new(ext_dtype, taken_storage).into_array())
75        }
76    }
77}
78
79register_kernel!(TakeKernelAdapter(ExtensionVTable).lift());
80
81impl MinMaxKernel for ExtensionVTable {
82    fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
83        Ok(
84            min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
85                min: Scalar::extension(array.ext_dtype().clone(), min),
86                max: Scalar::extension(array.ext_dtype().clone(), max),
87            }),
88        )
89    }
90}
91
92register_kernel!(MinMaxKernelAdapter(ExtensionVTable).lift());
93
94impl IsConstantKernel for ExtensionVTable {
95    fn is_constant(
96        &self,
97        array: &ExtensionArray,
98        opts: &IsConstantOpts,
99    ) -> VortexResult<Option<bool>> {
100        is_constant_opts(array.storage(), opts)
101    }
102}
103
104register_kernel!(IsConstantKernelAdapter(ExtensionVTable).lift());
105
106impl IsSortedKernel for ExtensionVTable {
107    fn is_sorted(&self, array: &ExtensionArray) -> VortexResult<bool> {
108        is_sorted(array.storage())
109    }
110
111    fn is_strict_sorted(&self, array: &ExtensionArray) -> VortexResult<bool> {
112        is_strict_sorted(array.storage())
113    }
114}
115
116register_kernel!(IsSortedKernelAdapter(ExtensionVTable).lift());
117
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::take::test_take_conformance;

    /// Builds a metadata-free "uuid" extension dtype over `u64` storage with the given
    /// nullability.
    fn u64_uuid_dtype(nullability: Nullability) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(DType::Primitive(PType::U64, nullability)),
            None,
        ))
    }

    /// Wraps `values` in a metadata-free "uuid" extension array whose storage dtype is taken
    /// from the values themselves.
    fn uuid_array(values: PrimitiveArray) -> ExtensionArray {
        let storage = values.into_array();
        let ext_dtype = ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    #[test]
    fn test_filter_extension_array() {
        // Simple extension type (e.g., a UUID represented as u64) over non-nullable storage.
        let non_nullable = ExtensionArray::new(
            u64_uuid_dtype(Nullability::NonNullable),
            PrimitiveArray::from_iter([1u64, 2, 3, 4, 5]).into_array(),
        );
        test_filter_conformance(non_nullable.as_ref());

        // Same extension type over nullable storage that contains nulls.
        let nullable = ExtensionArray::new(
            u64_uuid_dtype(Nullability::Nullable),
            PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None])
                .into_array(),
        );
        test_filter_conformance(nullable.as_ref());
    }

    #[rstest]
    // Simple extension type (non-nullable u64)
    #[case(uuid_array(PrimitiveArray::from_iter([1u64, 2, 3, 4, 5])))]
    // Nullable extension type
    #[case(uuid_array(PrimitiveArray::from_option_iter([
        Some(1u64),
        None,
        Some(3),
        Some(4),
        None,
    ])))]
    // Single element
    #[case(uuid_array(PrimitiveArray::from_iter([42u64])))]
    // Larger array for edge cases
    #[case(uuid_array(PrimitiveArray::from_iter(0u64..100)))]
    fn test_take_extension_array_conformance(#[case] array: ExtensionArray) {
        test_take_conformance(array.as_ref());
    }
}
202
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{ExtDType, ExtID};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::consistency::test_array_consistency;

    /// Wraps `values` in a metadata-free "test_ext" extension array whose storage dtype is
    /// taken from the values themselves.
    fn test_ext_array(values: PrimitiveArray) -> ExtensionArray {
        let storage = values.into_array();
        let ext_dtype = ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    #[rstest]
    // Note: the original test_all_consistency cases for extension arrays caused errors
    // because the "uuid" extension type is unsupported, so simpler cases using the
    // "test_ext" id are used here instead.
    #[case::extension_simple(test_ext_array(PrimitiveArray::from_iter([1u64, 2, 3, 4, 5])))]
    #[case::extension_nullable(test_ext_array(PrimitiveArray::from_option_iter([
        Some(1u64),
        None,
        Some(3),
        Some(4),
        None,
    ])))]
    // Additional test cases
    #[case::extension_single(test_ext_array(PrimitiveArray::from_iter([42i32])))]
    #[case::extension_large(test_ext_array(PrimitiveArray::from_iter(0..100i64)))]
    fn test_extension_consistency(#[case] array: ExtensionArray) {
        test_array_consistency(array.as_ref());
    }
}