vortex_array/arrays/constant/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::{AllOr, Mask};
6use vortex_scalar::Scalar;
7
8use crate::arrays::{ConstantArray, ConstantVTable};
9use crate::builders::builder_with_capacity;
10use crate::compute::{TakeKernel, TakeKernelAdapter};
11use crate::{Array, ArrayRef, IntoArray, register_kernel};
12
13impl TakeKernel for ConstantVTable {
14    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
15        match indices.validity_mask().boolean_buffer() {
16            AllOr::All => {
17                let scalar = Scalar::new(
18                    array
19                        .scalar()
20                        .dtype()
21                        .union_nullability(indices.dtype().nullability()),
22                    array.scalar().value().clone(),
23                );
24                Ok(ConstantArray::new(scalar, indices.len()).into_array())
25            }
26            AllOr::None => Ok(ConstantArray::new(
27                Scalar::null(
28                    array
29                        .dtype()
30                        .union_nullability(indices.dtype().nullability()),
31                ),
32                indices.len(),
33            )
34            .into_array()),
35            AllOr::Some(v) => {
36                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
37
38                if array.scalar().is_null() {
39                    return Ok(arr);
40                }
41
42                let mut result_builder =
43                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
44                result_builder.extend_from_array(&arr)?;
45                result_builder.set_validity(Mask::from_buffer(v.clone()));
46                Ok(result_builder.finish())
47            }
48        }
49    }
50}
51
// Registers the `TakeKernel` implementation for `ConstantVTable`, wrapped in a
// `TakeKernelAdapter` and lifted.
// NOTE(review): presumably this is what lets the generic `take` compute function find this
// kernel at runtime — confirm against the `register_kernel!` macro definition.
52register_kernel!(TakeKernelAdapter(ConstantVTable).lift());
53
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray, ToCanonical};

    /// Indices with mixed validity yield a nullable result in which only the positions
    /// backed by a valid index are valid.
    #[test]
    fn take_nullable_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter(vec![false, true, false]),
        )
        .into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let values = result.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[42; 3]);

        let expected_valid: &[usize] = &[1];
        assert_eq!(
            result.validity_mask().indices(),
            AllOr::Some(expected_valid)
        );
    }

    /// Indices that are all valid, but still nullable in dtype, yield a nullable result
    /// that is valid everywhere.
    #[test]
    fn take_all_valid_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();

        let result = take(&constant, &indices).unwrap();

        let expected_dtype = constant.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, result.dtype());

        let values = result.to_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[42; 3]);

        assert_eq!(result.validity_mask().indices(), AllOr::All);
    }

    /// Runs the shared take conformance suite over constants of several dtypes,
    /// including a null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_typed::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] constant: ConstantArray) {
        test_take_conformance(constant.as_ref());
    }
}