Skip to main content

vortex_array/arrays/masked/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::MaskedArray;
10use crate::arrays::MaskedVTable;
11use crate::arrays::TakeReduce;
12use crate::builtins::ArrayBuiltins;
13use crate::scalar::Scalar;
14use crate::vtable::ValidityHelper;
15
16impl TakeReduce for MaskedVTable {
17    fn take(array: &MaskedArray, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
18        let taken_child = if !indices.all_valid()? {
19            // This is safe because we'll mask out these positions in the validity.
20            let fill_scalar = Scalar::zero_value(indices.dtype());
21            let filled_take_indices = indices.to_array().fill_null(fill_scalar)?;
22            array.child.take(filled_take_indices)?
23        } else {
24            array.child.take(indices.to_array())?
25        };
26
27        // Compute the new validity by taking from array's validity and merging with indices validity
28        let taken_validity = array.validity().take(indices)?;
29
30        // Construct new MaskedArray
31        Ok(Some(
32            MaskedArray::try_new(taken_child, taken_validity)?.into_array(),
33        ))
34    }
35}
36
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::IntoArray;
    use crate::arrays::MaskedArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Builds a [`MaskedArray`] whose child is a primitive `i32` array of `values`,
    /// masked by `validity`.
    ///
    /// Panics if `MaskedArray::try_new` rejects the child/validity pair.
    fn masked_i32(values: impl IntoIterator<Item = i32>, validity: Validity) -> MaskedArray {
        let child = PrimitiveArray::from_iter(values).into_array();
        MaskedArray::try_new(child, validity).unwrap()
    }

    /// Runs the shared take conformance checks against several masked arrays.
    #[rstest]
    // Five elements with a mix of `true` and `false` validity flags.
    #[case(masked_i32(
        [1, 2, 3, 4, 5],
        Validity::from_iter([true, true, false, true, false])
    ))]
    // Three elements, `Validity::AllValid`.
    #[case(masked_i32([10, 20, 30], Validity::AllValid))]
    // One hundred elements; the validity flag is `false` at every index divisible by 3.
    #[case(masked_i32(
        0..100,
        Validity::from_iter((0..100).map(|i| i % 3 != 0))
    ))]
    fn test_take_masked_conformance(#[case] array: MaskedArray) {
        let array_ref = array.to_array();
        test_take_conformance(&array_ref);
    }
}