vortex_array/arrays/constant/compute/take.rs

1use vortex_error::VortexResult;
2use vortex_mask::{AllOr, Mask};
3use vortex_scalar::Scalar;
4
5use crate::arrays::{ConstantArray, ConstantVTable};
6use crate::builders::builder_with_capacity;
7use crate::compute::{TakeKernel, TakeKernelAdapter};
8use crate::{Array, ArrayRef, IntoArray, register_kernel};
9
10impl TakeKernel for ConstantVTable {
11    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12        match indices.validity_mask()?.boolean_buffer() {
13            AllOr::All => {
14                let scalar = Scalar::new(
15                    array
16                        .scalar()
17                        .dtype()
18                        .union_nullability(indices.dtype().nullability()),
19                    array.scalar().value().clone(),
20                );
21                Ok(ConstantArray::new(scalar, indices.len()).into_array())
22            }
23            AllOr::None => Ok(ConstantArray::new(
24                Scalar::null(
25                    array
26                        .dtype()
27                        .union_nullability(indices.dtype().nullability()),
28                ),
29                indices.len(),
30            )
31            .into_array()),
32            AllOr::Some(v) => {
33                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
34
35                if array.scalar().is_null() {
36                    return Ok(arr);
37                }
38
39                let mut result_builder =
40                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
41                result_builder.extend_from_array(&arr)?;
42                result_builder.set_validity(Mask::from_buffer(v.clone()));
43                Ok(result_builder.finish())
44            }
45        }
46    }
47}
48
49register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
50
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Partially-valid indices force materialization: the constant is broadcast and the
    /// indices' validity is carried over to the result.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(
            taken.validity_mask().unwrap().indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// Nullable but all-valid indices widen the dtype to nullable and keep every value valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// Non-nullable indices must not widen the constant's (non-nullable) dtype.
    #[test]
    fn take_non_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::NonNullable).into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(
            taken.to_primitive().unwrap().as_slice::<i32>(),
            &[42, 42, 42]
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// All-null indices produce an all-null result of the widened (nullable) dtype.
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    /// A null constant stays null at every position, whatever the indices' validity.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let array = ConstantArray::new(Scalar::null(dtype.clone()), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(&dtype, taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }
}