Skip to main content

vortex_array/arrays/constant/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::arrays::ConstantArray;
11use crate::arrays::ConstantVTable;
12use crate::arrays::MaskedArray;
13use crate::arrays::TakeReduce;
14use crate::arrays::TakeReduceAdaptor;
15use crate::optimizer::rules::ParentRuleSet;
16use crate::scalar::Scalar;
17use crate::validity::Validity;
18
19impl TakeReduce for ConstantVTable {
20    fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult<Option<ArrayRef>> {
21        let result = match indices.validity_mask()?.bit_buffer() {
22            AllOr::All => {
23                let scalar = Scalar::try_new(
24                    array
25                        .scalar()
26                        .dtype()
27                        .union_nullability(indices.dtype().nullability()),
28                    array.scalar().value().cloned(),
29                )?;
30                ConstantArray::new(scalar, indices.len()).into_array()
31            }
32            AllOr::None => ConstantArray::new(
33                Scalar::null(
34                    array
35                        .dtype()
36                        .union_nullability(indices.dtype().nullability()),
37                ),
38                indices.len(),
39            )
40            .into_array(),
41            AllOr::Some(v) => {
42                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
43
44                if array.scalar().is_null() {
45                    return Ok(Some(arr));
46                }
47
48                MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
49            }
50        };
51        Ok(Some(result))
52    }
53}
54
55impl ConstantVTable {
56    pub const TAKE_RULES: ParentRuleSet<Self> =
57        ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
58}
59
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Null indices must come back as nulls; valid indices yield the constant value.
    #[test]
    fn take_nullable_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter(vec![false, true, false]),
        )
        .into_array();

        let result = constant.take(indices).unwrap();

        // Only the middle position is expected to be valid.
        let expected_valid: &[usize] = &[1usize];
        assert_eq!(
            &constant.dtype().with_nullability(Nullability::Nullable),
            result.dtype()
        );
        assert_arrays_eq!(
            result.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            )
        );
        assert_eq!(
            result.validity_mask().unwrap().indices(),
            AllOr::Some(expected_valid)
        );
    }

    /// Indices declared `AllValid` give a nullable-typed result in which every element is valid.
    #[test]
    fn take_all_valid_indices() {
        let constant = ConstantArray::new(42, 10).to_array();
        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();

        let result = constant.take(indices).unwrap();

        assert_eq!(
            &constant.dtype().with_nullability(Nullability::Nullable),
            result.dtype()
        );
        assert_arrays_eq!(
            result.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid)
        );
        assert_eq!(result.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// Runs the shared take conformance suite over constants of several dtypes, including a
    /// null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] constant: ConstantArray) {
        test_take_conformance(constant.as_ref());
    }
}