vortex_array/arrays/constant/compute/take.rs

1use vortex_error::VortexResult;
2use vortex_mask::{AllOr, Mask};
3use vortex_scalar::Scalar;
4
5use crate::arrays::{ConstantArray, ConstantEncoding};
6use crate::builders::builder_with_capacity;
7use crate::compute::TakeFn;
8use crate::{Array, ArrayRef};
9
10impl TakeFn<&ConstantArray> for ConstantEncoding {
11    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12        match indices.validity_mask()?.boolean_buffer() {
13            AllOr::All => {
14                Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
15            }
16            AllOr::None => Ok(ConstantArray::new(
17                Scalar::null(array.dtype().clone()),
18                indices.len(),
19            )
20            .into_array()),
21            AllOr::Some(v) => {
22                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
23
24                if array.scalar().is_null() {
25                    return Ok(arr);
26                }
27
28                let mut result_builder =
29                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
30                result_builder.extend_from_array(&arr)?;
31                result_builder.set_validity(Mask::from_buffer(v.clone()));
32                Ok(result_builder.finish())
33            }
34        }
35    }
36}
37
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Taking from a constant array with partially-null indices repeats the constant at every
    /// output position and carries the index validity through to the result.
    #[test]
    fn take_nullable_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter([false, true, false]),
        )
        .into_array();

        let taken = take(&constant, &indices).unwrap();

        // Values are materialized at every position, including the null ones.
        let values = taken.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[42, 42, 42]);

        // Only the middle index was valid, so only position 1 of the result is valid.
        let expected_valid: &[usize] = &[1];
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::Some(expected_valid));
    }
}