vortex_array/arrays/constant/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6use vortex_scalar::Scalar;
7
8use crate::arrays::{ConstantArray, ConstantVTable, MaskedArray};
9use crate::compute::{TakeKernel, TakeKernelAdapter};
10use crate::validity::Validity;
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl TakeKernel for ConstantVTable {
14    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
15        match indices.validity_mask().boolean_buffer() {
16            AllOr::All => {
17                let scalar = Scalar::new(
18                    array
19                        .scalar()
20                        .dtype()
21                        .union_nullability(indices.dtype().nullability()),
22                    array.scalar().value().clone(),
23                );
24                Ok(ConstantArray::new(scalar, indices.len()).into_array())
25            }
26            AllOr::None => Ok(ConstantArray::new(
27                Scalar::null(
28                    array
29                        .dtype()
30                        .union_nullability(indices.dtype().nullability()),
31                ),
32                indices.len(),
33            )
34            .into_array()),
35            AllOr::Some(v) => {
36                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
37
38                if array.scalar().is_null() {
39                    return Ok(arr);
40                }
41
42                Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array())
43            }
44        }
45    }
46}
47
48register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
49
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Mixed-validity indices: the value repeats, and only the valid index positions are valid.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::Some(valid_indices));
    }

    /// Nullable but all-valid indices widen the dtype to nullable without introducing nulls.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::All);
    }

    /// Non-nullable indices must leave a non-nullable constant's dtype unchanged.
    #[test]
    fn take_non_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::NonNullable).into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
    }

    /// All-invalid indices yield an all-null result of the indices' length.
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    /// A null constant stays entirely null even when some indices are valid.
    #[test]
    fn take_null_constant_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_typed::<i32>(), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}