Skip to main content

vortex_array/arrays/constant/compute/take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::array::ArrayView;
12use crate::arrays::Constant;
13use crate::arrays::ConstantArray;
14use crate::arrays::MaskedArray;
15use crate::arrays::dict::TakeReduce;
16use crate::arrays::dict::TakeReduceAdaptor;
17use crate::optimizer::rules::ParentRuleSet;
18use crate::scalar::Scalar;
19use crate::validity::Validity;
20
21impl TakeReduce for Constant {
22    fn take(array: ArrayView<'_, Constant>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
23        let mut ctx = LEGACY_SESSION.create_execution_ctx();
24        let result = match indices
25            .validity()?
26            .execute_mask(indices.len(), &mut ctx)?
27            .bit_buffer()
28        {
29            AllOr::All => {
30                let scalar = Scalar::try_new(
31                    array
32                        .scalar()
33                        .dtype()
34                        .union_nullability(indices.dtype().nullability()),
35                    array.scalar().value().cloned(),
36                )?;
37                ConstantArray::new(scalar, indices.len()).into_array()
38            }
39            AllOr::None => ConstantArray::new(
40                Scalar::null(
41                    array
42                        .dtype()
43                        .union_nullability(indices.dtype().nullability()),
44                ),
45                indices.len(),
46            )
47            .into_array(),
48            AllOr::Some(v) => {
49                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
50
51                if array.scalar().is_null() {
52                    return Ok(Some(arr));
53                }
54
55                MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
56            }
57        };
58        Ok(Some(result))
59    }
60}
61
62impl Constant {
63    pub const TAKE_RULES: ParentRuleSet<Self> =
64        ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
65}
66
#[cfg(test)]
mod tests {
    use std::f64;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Mixed-validity indices: the result is nullable, repeats the constant, and is
    /// valid exactly at the positions where the indices are valid.
    #[test]
    fn take_nullable_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(
                PrimitiveArray::new(
                    buffer![0, 5, 7],
                    Validity::from_iter(vec![false, true, false]),
                )
                .into_array(),
            )
            .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            ),
            &mut ctx
        );
        // Reuse the test's execution context rather than creating a second one.
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .execute_mask(taken.len(), &mut ctx)
                .unwrap()
                .indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// All-valid (but nullable) indices: the result takes the indices' nullability
    /// and every element is valid.
    #[test]
    fn take_all_valid_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid),
            &mut ctx
        );
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .execute_mask(taken.len(), &mut ctx)
                .unwrap()
                .indices(),
            AllOr::All
        );
    }

    /// All-null indices: the result has one element per index, has a nullable
    /// dtype, and has no valid positions.
    #[test]
    fn take_all_null_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array())
            .unwrap();
        assert_eq!(taken.len(), 3);
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .execute_mask(taken.len(), &mut ctx)
                .unwrap()
                .indices(),
            AllOr::None
        );
    }

    /// Shared take conformance suite over a range of constant scalar types,
    /// including a null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(&array.into_array());
    }
}