vortex_array/arrays/constant/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6use vortex_scalar::Scalar;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::arrays::ConstantVTable;
13use crate::arrays::MaskedArray;
14use crate::compute::TakeKernel;
15use crate::compute::TakeKernelAdapter;
16use crate::register_kernel;
17use crate::validity::Validity;
18
19impl TakeKernel for ConstantVTable {
20    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21        match indices.validity_mask().bit_buffer() {
22            AllOr::All => {
23                let scalar = Scalar::new(
24                    array
25                        .scalar()
26                        .dtype()
27                        .union_nullability(indices.dtype().nullability()),
28                    array.scalar().value().clone(),
29                );
30                Ok(ConstantArray::new(scalar, indices.len()).into_array())
31            }
32            AllOr::None => Ok(ConstantArray::new(
33                Scalar::null(
34                    array
35                        .dtype()
36                        .union_nullability(indices.dtype().nullability()),
37                ),
38                indices.len(),
39            )
40            .into_array()),
41            AllOr::Some(v) => {
42                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
43
44                if array.scalar().is_null() {
45                    return Ok(arr);
46                }
47
48                Ok(MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array())
49            }
50        }
51    }
52}
53
54register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
55
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Partially-valid indices: values are the constant, validity follows the indices.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::Some(valid_indices));
    }

    /// Nullable-but-all-valid indices: every position is valid, dtype becomes nullable.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.to_primitive().as_slice::<i32>(), &[42, 42, 42]);
        assert_eq!(taken.validity_mask().indices(), AllOr::All);
    }

    /// All indices null: every output position is null and the dtype becomes nullable.
    #[test]
    fn take_all_invalid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    /// A null constant taken with partially-valid indices stays all-null. This is intended to
    /// cover the kernel's path that skips the `MaskedArray` wrapper for null constants.
    #[test]
    fn take_null_constant_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_typed::<i32>(), 10).to_array();
        let taken = take(
            &array,
            &PrimitiveArray::new(
                buffer![0, 5, 7],
                Validity::from_iter(vec![false, true, false]),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().indices(), AllOr::None);
    }

    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}