vortex_array/arrays/extension/compute/
mod.rs1mod cast;
5mod compare;
6
7use std::sync::Arc;
8
9use vortex_dtype::ExtDType;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12use vortex_scalar::Scalar;
13
14use crate::arrays::ExtensionVTable;
15use crate::arrays::extension::ExtensionArray;
16use crate::compute::{
17 FilterKernel, FilterKernelAdapter, IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts,
18 IsSortedKernel, IsSortedKernelAdapter, MaskKernel, MaskKernelAdapter, MinMaxKernel,
19 MinMaxKernelAdapter, MinMaxResult, SumKernel, SumKernelAdapter, TakeKernel, TakeKernelAdapter,
20 filter, is_constant_opts, is_sorted, is_strict_sorted, mask, min_max, sum, take,
21};
22use crate::{Array, ArrayRef, IntoArray, register_kernel};
23
24impl FilterKernel for ExtensionVTable {
25 fn filter(&self, array: &ExtensionArray, mask: &Mask) -> VortexResult<ArrayRef> {
26 Ok(
27 ExtensionArray::new(array.ext_dtype().clone(), filter(array.storage(), mask)?)
28 .into_array(),
29 )
30 }
31}
32
33register_kernel!(FilterKernelAdapter(ExtensionVTable).lift());
34
35impl MaskKernel for ExtensionVTable {
36 fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
37 let masked_storage = mask(array.storage(), mask_array)?;
38 if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
39 Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array())
40 } else {
41 let ext_dtype = Arc::new(ExtDType::new(
43 array.ext_dtype().id().clone(),
44 Arc::new(masked_storage.dtype().clone()),
45 array.ext_dtype().metadata().cloned(),
46 ));
47 Ok(ExtensionArray::new(ext_dtype, masked_storage).into_array())
48 }
49 }
50}
51
52register_kernel!(MaskKernelAdapter(ExtensionVTable).lift());
53
54impl SumKernel for ExtensionVTable {
55 fn sum(&self, array: &ExtensionArray) -> VortexResult<Scalar> {
56 sum(array.storage())
57 }
58}
59
60register_kernel!(SumKernelAdapter(ExtensionVTable).lift());
61
62impl TakeKernel for ExtensionVTable {
63 fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
64 let taken_storage = take(array.storage(), indices)?;
65 if taken_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
66 Ok(ExtensionArray::new(array.ext_dtype().clone(), taken_storage).into_array())
67 } else {
68 let ext_dtype = Arc::new(ExtDType::new(
70 array.ext_dtype().id().clone(),
71 Arc::new(taken_storage.dtype().clone()),
72 array.ext_dtype().metadata().cloned(),
73 ));
74 Ok(ExtensionArray::new(ext_dtype, taken_storage).into_array())
75 }
76 }
77}
78
79register_kernel!(TakeKernelAdapter(ExtensionVTable).lift());
80
81impl MinMaxKernel for ExtensionVTable {
82 fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
83 Ok(
84 min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
85 min: Scalar::extension(array.ext_dtype().clone(), min),
86 max: Scalar::extension(array.ext_dtype().clone(), max),
87 }),
88 )
89 }
90}
91
92register_kernel!(MinMaxKernelAdapter(ExtensionVTable).lift());
93
94impl IsConstantKernel for ExtensionVTable {
95 fn is_constant(
96 &self,
97 array: &ExtensionArray,
98 opts: &IsConstantOpts,
99 ) -> VortexResult<Option<bool>> {
100 is_constant_opts(array.storage(), opts)
101 }
102}
103
104register_kernel!(IsConstantKernelAdapter(ExtensionVTable).lift());
105
106impl IsSortedKernel for ExtensionVTable {
107 fn is_sorted(&self, array: &ExtensionArray) -> VortexResult<bool> {
108 is_sorted(array.storage())
109 }
110
111 fn is_strict_sorted(&self, array: &ExtensionArray) -> VortexResult<bool> {
112 is_strict_sorted(array.storage())
113 }
114}
115
116register_kernel!(IsSortedKernelAdapter(ExtensionVTable).lift());
117
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::take::test_take_conformance;

    /// Builds a metadata-free `uuid` extension dtype over `u64` storage with the given
    /// nullability.
    fn uuid_dtype(nullability: Nullability) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(DType::Primitive(PType::U64, nullability)),
            None,
        ))
    }

    /// Wraps `values` in a metadata-free `uuid` extension array whose storage dtype is taken
    /// from the values themselves.
    fn uuid_array(values: PrimitiveArray) -> ExtensionArray {
        let storage = values.into_array();
        let ext_dtype = ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    /// Runs the filter conformance suite on a non-nullable and a nullable extension array.
    #[test]
    fn test_filter_extension_array() {
        let non_nullable = ExtensionArray::new(
            uuid_dtype(Nullability::NonNullable),
            PrimitiveArray::from_iter([1u64, 2, 3, 4, 5]).into_array(),
        );
        test_filter_conformance(non_nullable.as_ref());

        let nullable = ExtensionArray::new(
            uuid_dtype(Nullability::Nullable),
            PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None])
                .into_array(),
        );
        test_filter_conformance(nullable.as_ref());
    }

    /// Runs the take conformance suite on small, nullable, single-element and larger arrays.
    #[rstest]
    #[case(uuid_array(PrimitiveArray::from_iter([1u64, 2, 3, 4, 5])))]
    #[case(uuid_array(PrimitiveArray::from_option_iter([
        Some(1u64),
        None,
        Some(3),
        Some(4),
        None,
    ])))]
    #[case(uuid_array(PrimitiveArray::from_iter([42u64])))]
    #[case(uuid_array(PrimitiveArray::from_iter(0u64..100)))]
    fn test_take_extension_array_conformance(#[case] array: ExtensionArray) {
        test_take_conformance(array.as_ref());
    }
}
202
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_dtype::{ExtDType, ExtID};

    use crate::IntoArray;
    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::consistency::test_array_consistency;

    /// Wraps `values` in a metadata-free `test_ext` extension array whose storage dtype is
    /// taken from the values themselves.
    fn test_ext_array(values: PrimitiveArray) -> ExtensionArray {
        let storage = values.into_array();
        let ext_dtype = ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    /// Runs the array consistency suite over extension arrays of several shapes and storage
    /// primitive types.
    #[rstest]
    #[case::extension_simple(test_ext_array(PrimitiveArray::from_iter([1u64, 2, 3, 4, 5])))]
    #[case::extension_nullable(test_ext_array(PrimitiveArray::from_option_iter([
        Some(1u64),
        None,
        Some(3),
        Some(4),
        None,
    ])))]
    #[case::extension_single(test_ext_array(PrimitiveArray::from_iter([42i32])))]
    #[case::extension_large(test_ext_array(PrimitiveArray::from_iter(0..100i64)))]
    fn test_extension_consistency(#[case] array: ExtensionArray) {
        test_array_consistency(array.as_ref());
    }
}