vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::{AllOr, Mask};
6use vortex_scalar::Scalar;
7
8use crate::arrays::{ConstantArray, ConstantVTable};
9use crate::builders::builder_with_capacity;
10use crate::compute::{TakeKernel, TakeKernelAdapter};
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl TakeKernel for ConstantVTable {
14 fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
15 match indices.validity_mask()?.boolean_buffer() {
16 AllOr::All => {
17 let scalar = Scalar::new(
18 array
19 .scalar()
20 .dtype()
21 .union_nullability(indices.dtype().nullability()),
22 array.scalar().value().clone(),
23 );
24 Ok(ConstantArray::new(scalar, indices.len()).into_array())
25 }
26 AllOr::None => Ok(ConstantArray::new(
27 Scalar::null(
28 array
29 .dtype()
30 .union_nullability(indices.dtype().nullability()),
31 ),
32 indices.len(),
33 )
34 .into_array()),
35 AllOr::Some(v) => {
36 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
37
38 if array.scalar().is_null() {
39 return Ok(arr);
40 }
41
42 let mut result_builder =
43 builder_with_capacity(&array.dtype().as_nullable(), indices.len());
44 result_builder.extend_from_array(&arr)?;
45 result_builder.set_validity(Mask::from_buffer(v.clone()));
46 Ok(result_builder.finish())
47 }
48 }
49 }
50}
51
52register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
53
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Partially-null indices over a non-null constant: the result is nullable, holds the
    /// constant at every position, and is valid only where the index was valid.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(
            taken.validity_mask().unwrap().indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// All-valid (but nullable) indices: the result takes the nullable dtype of the indices
    /// and every position is valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// All indices null: the result is nullable and every output position is null.
    #[test]
    fn take_all_null_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    /// Null constant with partially-null indices: the result keeps the constant's (nullable)
    /// dtype and is null everywhere, regardless of the index validity.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_typed::<i64>(), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    /// Runs the shared take conformance suite over constants of several dtypes, including a
    /// null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}