vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
2use vortex_mask::{AllOr, Mask};
3use vortex_scalar::Scalar;
4
5use crate::arrays::{ConstantArray, ConstantEncoding};
6use crate::builders::builder_with_capacity;
7use crate::compute::TakeFn;
8use crate::{Array, ArrayRef};
9
10impl TakeFn<&ConstantArray> for ConstantEncoding {
11 fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12 match indices.validity_mask()?.boolean_buffer() {
13 AllOr::All => {
14 let nullability = array.dtype().nullability() | indices.dtype().nullability();
15 let scalar = Scalar::new(
16 array.scalar().dtype().with_nullability(nullability),
17 array.scalar().value().clone(),
18 );
19 Ok(ConstantArray::new(scalar, indices.len()).into_array())
20 }
21 AllOr::None => {
22 Ok(ConstantArray::new(
23 Scalar::null(array.dtype().with_nullability(
24 array.dtype().nullability() | indices.dtype().nullability(),
25 )),
26 indices.len(),
27 )
28 .into_array())
29 }
30 AllOr::Some(v) => {
31 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
32
33 if array.scalar().is_null() {
34 return Ok(arr);
35 }
36
37 let mut result_builder =
38 builder_with_capacity(&array.dtype().as_nullable(), indices.len());
39 result_builder.extend_from_array(&arr)?;
40 result_builder.set_validity(Mask::from_buffer(v.clone()));
41 Ok(result_builder.finish())
42 }
43 }
44 }
45}
46
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Indices with a mix of valid and invalid slots: the values are all the
    /// constant, and only the valid index position is valid in the output.
    #[test]
    fn take_nullable_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let index_validity = Validity::from_iter(vec![false, true, false]);
        let indices = PrimitiveArray::new(buffer![0, 5, 7], index_validity).into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_valid: &[usize] = &[1usize];
        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let values = result.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[42, 42, 42]);

        let result_validity = result.validity_mask().unwrap();
        assert_eq!(result_validity.indices(), AllOr::Some(expected_valid));
    }

    /// Indices declared with `Validity::AllValid`: the output dtype is
    /// nullable and every output slot is valid.
    #[test]
    fn take_all_valid_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let values = result.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[42, 42, 42]);

        let result_validity = result.validity_mask().unwrap();
        assert_eq!(result_validity.indices(), AllOr::All);
    }
}