vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
2use vortex_mask::{AllOr, Mask};
3use vortex_scalar::Scalar;
4
5use crate::arrays::{ConstantArray, ConstantVTable};
6use crate::builders::builder_with_capacity;
7use crate::compute::{TakeKernel, TakeKernelAdapter};
8use crate::{Array, ArrayRef, IntoArray, register_kernel};
9
10impl TakeKernel for ConstantVTable {
11 fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12 match indices.validity_mask()?.boolean_buffer() {
13 AllOr::All => {
14 let scalar = Scalar::new(
15 array
16 .scalar()
17 .dtype()
18 .union_nullability(indices.dtype().nullability()),
19 array.scalar().value().clone(),
20 );
21 Ok(ConstantArray::new(scalar, indices.len()).into_array())
22 }
23 AllOr::None => Ok(ConstantArray::new(
24 Scalar::null(
25 array
26 .dtype()
27 .union_nullability(indices.dtype().nullability()),
28 ),
29 indices.len(),
30 )
31 .into_array()),
32 AllOr::Some(v) => {
33 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
34
35 if array.scalar().is_null() {
36 return Ok(arr);
37 }
38
39 let mut result_builder =
40 builder_with_capacity(&array.dtype().as_nullable(), indices.len());
41 result_builder.extend_from_array(&arr)?;
42 result_builder.set_validity(Mask::from_buffer(v.clone()));
43 Ok(result_builder.finish())
44 }
45 }
46 }
47}
48
// Register the `TakeKernel` impl above, wrapped in `TakeKernelAdapter` and `lift()`ed, through
// `register_kernel!`. Presumably this lets the generic `take` compute function dispatch to it for
// constant arrays; confirm against the macro definition.
register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
50
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Mixed index validity: the value is materialized and the index validity is applied.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(
            taken.validity_mask().unwrap().indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// Nullable but fully valid indices: the dtype widens to nullable and nothing is null.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// Non-nullable indices must not widen the dtype of the result.
    #[test]
    fn take_non_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::NonNullable).into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// Every index null: the result is an all-null array of the widened dtype.
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    /// A null constant stays all-null regardless of which indices are valid.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let array = ConstantArray::new(Scalar::null(dtype.clone()), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(&dtype, taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }
}