vortex_array/arrays/constant/compute/
take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::{AllOr, Mask};
6use vortex_scalar::Scalar;
7
8use crate::arrays::{ConstantArray, ConstantVTable};
9use crate::builders::builder_with_capacity;
10use crate::compute::{TakeKernel, TakeKernelAdapter};
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl TakeKernel for ConstantVTable {
14    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
15        match indices.validity_mask()?.boolean_buffer() {
16            AllOr::All => {
17                let scalar = Scalar::new(
18                    array
19                        .scalar()
20                        .dtype()
21                        .union_nullability(indices.dtype().nullability()),
22                    array.scalar().value().clone(),
23                );
24                Ok(ConstantArray::new(scalar, indices.len()).into_array())
25            }
26            AllOr::None => Ok(ConstantArray::new(
27                Scalar::null(
28                    array
29                        .dtype()
30                        .union_nullability(indices.dtype().nullability()),
31                ),
32                indices.len(),
33            )
34            .into_array()),
35            AllOr::Some(v) => {
36                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
37
38                if array.scalar().is_null() {
39                    return Ok(arr);
40                }
41
42                let mut result_builder =
43                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
44                result_builder.extend_from_array(&arr)?;
45                result_builder.set_validity(Mask::from_buffer(v.clone()));
46                Ok(result_builder.finish())
47            }
48        }
49    }
50}
51
52register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
53
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Indices with mixed validity: the result is nullable, every slot stores the constant,
    /// and only the position of the valid index is valid in the output.
    #[test]
    fn take_nullable_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter(vec![false, true, false]),
        )
        .into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let primitive = result.to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<i32>(), &[42, 42, 42]);

        let expected_valid: &[usize] = &[1usize];
        let mask = result.validity_mask().unwrap();
        assert_eq!(mask.indices(), AllOr::Some(expected_valid));
    }

    /// Indices declared `Validity::AllValid`: the result is nullable and fully valid.
    #[test]
    fn take_all_valid_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let primitive = result.to_primitive().unwrap();
        assert_eq!(primitive.as_slice::<i32>(), &[42, 42, 42]);

        let mask = result.validity_mask().unwrap();
        assert_eq!(mask.indices(), AllOr::All);
    }

    /// Runs the shared take conformance checks over several constant arrays.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] constant: ConstantArray) {
        test_take_conformance(constant.as_ref());
    }
}