vortex_array/arrays/constant/compute/take.rs

1use vortex_error::VortexResult;
2use vortex_mask::{AllOr, Mask};
3use vortex_scalar::Scalar;
4
5use crate::arrays::{ConstantArray, ConstantEncoding};
6use crate::builders::builder_with_capacity;
7use crate::compute::TakeFn;
8use crate::{Array, ArrayRef};
9
10impl TakeFn<&ConstantArray> for ConstantEncoding {
11    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12        match indices.validity_mask()?.boolean_buffer() {
13            AllOr::All => {
14                let nullability = array.dtype().nullability() | indices.dtype().nullability();
15                let scalar = Scalar::new(
16                    array.scalar().dtype().with_nullability(nullability),
17                    array.scalar().value().clone(),
18                );
19                Ok(ConstantArray::new(scalar, indices.len()).into_array())
20            }
21            AllOr::None => {
22                Ok(ConstantArray::new(
23                    Scalar::null(array.dtype().with_nullability(
24                        array.dtype().nullability() | indices.dtype().nullability(),
25                    )),
26                    indices.len(),
27                )
28                .into_array())
29            }
30            AllOr::Some(v) => {
31                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
32
33                if array.scalar().is_null() {
34                    return Ok(arr);
35                }
36
37                let mut result_builder =
38                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
39                result_builder.extend_from_array(&arr)?;
40                result_builder.set_validity(Mask::from_buffer(v.clone()));
41                Ok(result_builder.finish())
42            }
43        }
44    }
45}
46
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Mixed-validity indices: every output slot carries the constant value,
    /// but only the slot whose index is valid is itself valid.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(
            taken.validity_mask().unwrap().indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// All-valid (but nullable-typed) indices: the values are all the constant,
    /// the result is fully valid, and its dtype picks up the indices'
    /// nullability.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// All-null indices (the `AllOr::None` branch of `take`): the result is a
    /// nullable array of the requested length in which every slot is null.
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }
}
103}