vortex_array/arrays/masked/vtable/canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5
6use crate::arrays::{ConstantVTable, MaskedArray, MaskedVTable};
7use crate::compute::mask;
8use crate::vtable::CanonicalVTable;
9use crate::{Array, Canonical};
10
11impl CanonicalVTable<MaskedVTable> for MaskedVTable {
12    fn canonicalize(array: &MaskedArray) -> Canonical {
13        if array.child.is::<ConstantVTable>() {
14            // To allow constant array to produce masked array from mask call, we have to unwrap constant here and canonicalize it first
15            mask(
16                array.child.to_canonical().as_ref(),
17                &!array.validity.to_mask(array.len()),
18            )
19            .vortex_expect("constant masked to canonical")
20            .to_canonical()
21        } else {
22            array
23                .masked_child()
24                .vortex_expect("masked child of a masked array")
25                .to_canonical()
26        }
27    }
28}
29
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Canonicalizing a masked array yields a nullable array whose dtype matches the masked
    /// array's own dtype, whether or not any positions are actually masked out.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap(),
        Nullability::Nullable
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([true, false, true])
        ).unwrap(),
        Nullability::Nullable
    )]
    fn test_canonical_nullability(
        #[case] array: MaskedArray,
        #[case] expected_nullability: Nullability,
    ) {
        let canonical = array.to_canonical();
        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            expected_nullability
        );
        assert_eq!(canonical.as_ref().dtype(), array.dtype());
    }

    /// Null positions in the canonical output follow `validity`, and valid positions keep
    /// the child's values.
    #[test]
    fn test_canonical_with_nulls() {
        let array = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, false, true, false, true]),
        )
        .unwrap();

        let canonical = array.to_canonical();
        let prim = canonical.as_ref().to_primitive();

        // Check that null positions match validity.
        assert_eq!(prim.valid_count(), 3);
        assert!(prim.is_valid(0));
        assert!(!prim.is_valid(1));
        assert!(prim.is_valid(2));
        assert!(!prim.is_valid(3));
        assert!(prim.is_valid(4));

        // Valid positions must still carry the original child values.
        let values = prim.as_slice::<i32>();
        assert_eq!(values[0], 1);
        assert_eq!(values[2], 3);
        assert_eq!(values[4], 5);
    }

    /// With all-valid validity, every position stays valid but the dtype is still nullable.
    #[test]
    fn test_canonical_all_valid() {
        let array = MaskedArray::try_new(
            PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let canonical = array.to_canonical();
        assert_eq!(canonical.as_ref().valid_count(), 3);
        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            Nullability::Nullable
        );
    }

    /// Exercises the `ConstantVTable` branch of `canonicalize`, which canonicalizes the
    /// constant child before applying the inverted validity mask.
    #[test]
    fn test_canonical_constant_child() {
        let array = MaskedArray::try_new(
            ConstantArray::new(5i32, 3).into_array(),
            Validity::from_iter([true, false, true]),
        )
        .unwrap();

        let canonical = array.to_canonical();
        assert_eq!(canonical.as_ref().dtype(), array.dtype());

        let prim = canonical.as_ref().to_primitive();
        assert_eq!(prim.valid_count(), 2);
        assert!(prim.is_valid(0));
        assert!(!prim.is_valid(1));
        assert!(prim.is_valid(2));

        let values = prim.as_slice::<i32>();
        assert_eq!(values[0], 5);
        assert_eq!(values[2], 5);
    }
}