vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::arrays::ConstantVTable;
13use crate::arrays::MaskedArray;
14use crate::compute::TakeKernel;
15use crate::compute::TakeKernelAdapter;
16use crate::register_kernel;
17use crate::validity::Validity;
18
19impl TakeKernel for ConstantVTable {
20 fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21 match indices.validity_mask().bit_buffer() {
22 AllOr::All => {
23 let scalar = Scalar::new(
24 array
25 .scalar()
26 .dtype()
27 .union_nullability(indices.dtype().nullability()),
28 array.scalar().value().clone(),
29 );
30 Ok(ConstantArray::new(scalar, indices.len()).into_array())
31 }
32 AllOr::None => Ok(ConstantArray::new(
33 Scalar::null(
34 array
35 .dtype()
36 .union_nullability(indices.dtype().nullability()),
37 ),
38 indices.len(),
39 )
40 .into_array()),
41 AllOr::Some(v) => {
42 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
43
44 if array.scalar().is_null() {
45 return Ok(arr);
46 }
47
48 Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array())
49 }
50 }
51 }
52}
53
54register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
55
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Partially-valid indices keep the constant value at every position and
    /// carry the indices' validity into the (nullable) result.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::Some(valid_indices));
    }

    /// All-valid (but nullable-typed) indices produce a fully valid result
    /// whose dtype is widened to nullable.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::All);
    }

    /// All-null indices must yield an entirely null result of the nullable
    /// input dtype (covers the all-invalid indices path).
    #[test]
    fn take_all_null_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    /// A null constant taken with partially-valid indices stays entirely null
    /// and keeps its (already nullable) dtype.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_typed::<i32>(), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    /// Runs the shared take conformance suite over a variety of constant dtypes,
    /// including a null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}