vortex_array/arrays/extension/compute/
mod.rs1mod cast;
5mod compare;
6
7use std::sync::Arc;
8
9use vortex_dtype::ExtDType;
10use vortex_error::VortexResult;
11use vortex_mask::Mask;
12use vortex_scalar::Scalar;
13
14use crate::arrays::ExtensionVTable;
15use crate::arrays::extension::ExtensionArray;
16use crate::compute::{
17 FilterKernel, FilterKernelAdapter, IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts,
18 IsSortedKernel, IsSortedKernelAdapter, MaskKernel, MaskKernelAdapter, MinMaxKernel,
19 MinMaxKernelAdapter, MinMaxResult, SumKernel, SumKernelAdapter, TakeKernel, TakeKernelAdapter,
20 filter, is_constant_opts, is_sorted, is_strict_sorted, mask, min_max, sum, take,
21};
22use crate::{Array, ArrayRef, IntoArray, register_kernel};
23
24impl FilterKernel for ExtensionVTable {
25 fn filter(&self, array: &ExtensionArray, mask: &Mask) -> VortexResult<ArrayRef> {
26 Ok(
27 ExtensionArray::new(array.ext_dtype().clone(), filter(array.storage(), mask)?)
28 .into_array(),
29 )
30 }
31}
32
33register_kernel!(FilterKernelAdapter(ExtensionVTable).lift());
34
35impl MaskKernel for ExtensionVTable {
36 fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
37 let masked_storage = mask(array.storage(), mask_array)?;
38 if masked_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
39 Ok(ExtensionArray::new(array.ext_dtype().clone(), masked_storage).into_array())
40 } else {
41 let ext_dtype = Arc::new(ExtDType::new(
43 array.ext_dtype().id().clone(),
44 Arc::new(masked_storage.dtype().clone()),
45 array.ext_dtype().metadata().cloned(),
46 ));
47 Ok(ExtensionArray::new(ext_dtype, masked_storage).into_array())
48 }
49 }
50}
51
52register_kernel!(MaskKernelAdapter(ExtensionVTable).lift());
53
54impl SumKernel for ExtensionVTable {
55 fn sum(&self, array: &ExtensionArray) -> VortexResult<Scalar> {
56 sum(array.storage())
57 }
58}
59
60register_kernel!(SumKernelAdapter(ExtensionVTable).lift());
61
62impl TakeKernel for ExtensionVTable {
63 fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
64 let taken_storage = take(array.storage(), indices)?;
65 if taken_storage.dtype().nullability() == array.ext_dtype().storage_dtype().nullability() {
66 Ok(ExtensionArray::new(array.ext_dtype().clone(), taken_storage).into_array())
67 } else {
68 let ext_dtype = Arc::new(ExtDType::new(
70 array.ext_dtype().id().clone(),
71 Arc::new(taken_storage.dtype().clone()),
72 array.ext_dtype().metadata().cloned(),
73 ));
74 Ok(ExtensionArray::new(ext_dtype, taken_storage).into_array())
75 }
76 }
77}
78
79register_kernel!(TakeKernelAdapter(ExtensionVTable).lift());
80
81impl MinMaxKernel for ExtensionVTable {
82 fn min_max(&self, array: &ExtensionArray) -> VortexResult<Option<MinMaxResult>> {
83 Ok(
84 min_max(array.storage())?.map(|MinMaxResult { min, max }| MinMaxResult {
85 min: Scalar::extension(array.ext_dtype().clone(), min),
86 max: Scalar::extension(array.ext_dtype().clone(), max),
87 }),
88 )
89 }
90}
91
92register_kernel!(MinMaxKernelAdapter(ExtensionVTable).lift());
93
94impl IsConstantKernel for ExtensionVTable {
95 fn is_constant(
96 &self,
97 array: &ExtensionArray,
98 opts: &IsConstantOpts,
99 ) -> VortexResult<Option<bool>> {
100 is_constant_opts(array.storage(), opts)
101 }
102}
103
104register_kernel!(IsConstantKernelAdapter(ExtensionVTable).lift());
105
106impl IsSortedKernel for ExtensionVTable {
107 fn is_sorted(&self, array: &ExtensionArray) -> VortexResult<Option<bool>> {
108 is_sorted(array.storage())
109 }
110
111 fn is_strict_sorted(&self, array: &ExtensionArray) -> VortexResult<Option<bool>> {
112 is_strict_sorted(array.storage())
113 }
114}
115
116register_kernel!(IsSortedKernelAdapter(ExtensionVTable).lift());
117
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType};

    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::{ArrayRef, IntoArray};

    /// Builds a `"uuid"` extension dtype without metadata over `U64` storage
    /// with the requested nullability.
    fn uuid_u64_dtype(nullability: Nullability) -> Arc<ExtDType> {
        Arc::new(ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(DType::Primitive(PType::U64, nullability)),
            None,
        ))
    }

    /// Wraps `storage` in a `"uuid"` extension array whose storage dtype is
    /// copied from `storage` itself.
    fn uuid_array_over(storage: ArrayRef) -> ExtensionArray {
        let ext_dtype = ExtDType::new(
            ExtID::new("uuid".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    /// Runs the filter conformance suite over a non-nullable and a nullable
    /// extension array.
    #[test]
    fn test_filter_extension_array() {
        let non_nullable = ExtensionArray::new(
            uuid_u64_dtype(Nullability::NonNullable),
            buffer![1u64, 2, 3, 4, 5].into_array(),
        );
        test_filter_conformance(non_nullable.as_ref());

        let nullable = ExtensionArray::new(
            uuid_u64_dtype(Nullability::Nullable),
            PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None])
                .into_array(),
        );
        test_filter_conformance(nullable.as_ref());
    }

    /// Runs the take conformance suite over several extension array fixtures.
    #[rstest]
    #[case(uuid_array_over(buffer![1u64, 2, 3, 4, 5].into_array()))]
    #[case(uuid_array_over(
        PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None]).into_array()
    ))]
    #[case(uuid_array_over(buffer![42u64].into_array()))]
    #[case(uuid_array_over(buffer![0u64..100].into_array()))]
    fn test_take_extension_array_conformance(#[case] array: ExtensionArray) {
        test_take_conformance(array.as_ref());
    }
}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{ExtDType, ExtID};

    use crate::arrays::{ExtensionArray, PrimitiveArray};
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::{ArrayRef, IntoArray};

    /// Wraps `storage` in a `"test_ext"` extension array (no metadata) whose
    /// storage dtype is copied from `storage` itself.
    fn test_ext_over(storage: ArrayRef) -> ExtensionArray {
        let ext_dtype = ExtDType::new(
            ExtID::new("test_ext".into()),
            Arc::new(storage.dtype().clone()),
            None,
        );
        ExtensionArray::new(Arc::new(ext_dtype), storage)
    }

    /// Runs the array consistency suite over several extension array fixtures.
    #[rstest]
    #[case::extension_simple(test_ext_over(buffer![1u64, 2, 3, 4, 5].into_array()))]
    #[case::extension_nullable(test_ext_over(
        PrimitiveArray::from_option_iter([Some(1u64), None, Some(3), Some(4), None]).into_array()
    ))]
    #[case::extension_single(test_ext_over(buffer![42i32].into_array()))]
    #[case::extension_large(test_ext_over(buffer![0..100i64].into_array()))]
    fn test_extension_consistency(#[case] array: ExtensionArray) {
        test_array_consistency(array.as_ref());
    }
}