vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6use vortex_scalar::Scalar;
7
8use crate::arrays::{ConstantArray, ConstantVTable, MaskedArray};
9use crate::compute::{TakeKernel, TakeKernelAdapter};
10use crate::validity::Validity;
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl TakeKernel for ConstantVTable {
14 fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
15 match indices.validity_mask().boolean_buffer() {
16 AllOr::All => {
17 let scalar = Scalar::new(
18 array
19 .scalar()
20 .dtype()
21 .union_nullability(indices.dtype().nullability()),
22 array.scalar().value().clone(),
23 );
24 Ok(ConstantArray::new(scalar, indices.len()).into_array())
25 }
26 AllOr::None => Ok(ConstantArray::new(
27 Scalar::null(
28 array
29 .dtype()
30 .union_nullability(indices.dtype().nullability()),
31 ),
32 indices.len(),
33 )
34 .into_array()),
35 AllOr::Some(v) => {
36 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
37
38 if array.scalar().is_null() {
39 return Ok(arr);
40 }
41
42 Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array())
43 }
44 }
45 }
46}
47
48register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Mixed-validity indices: the constant value is kept, but the output adopts the indices'
    /// nulls and becomes nullable.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::Some(valid_indices));
    }

    /// Nullable but fully valid indices: every output slot is valid, dtype is widened to nullable.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::All);
    }

    /// All-null indices: every output slot is null regardless of the constant's value
    /// (exercises the `AllOr::None` branch of the kernel).
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(taken.len(), 3);
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    /// A null constant taken with mixed-validity indices stays null everywhere
    /// (exercises the early return that skips wrapping a null constant in a `MaskedArray`).
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_typed::<i32>(), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(taken.len(), 3);
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}