vortex_array/arrays/masked/vtable/canonical.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_error::VortexExpect;
5
6use crate::Array;
7use crate::Canonical;
8use crate::arrays::ConstantVTable;
9use crate::arrays::MaskedArray;
10use crate::arrays::MaskedVTable;
11use crate::compute::mask;
12use crate::vtable::CanonicalVTable;
13
14impl CanonicalVTable<MaskedVTable> for MaskedVTable {
15    fn canonicalize(array: &MaskedArray) -> Canonical {
16        if array.child.is::<ConstantVTable>() {
17            // To allow constant array to produce masked array from mask call, we have to unwrap constant here and canonicalize it first
18            mask(
19                array.child.to_canonical().as_ref(),
20                &!array.validity.to_mask(array.len()),
21            )
22            .vortex_expect("constant masked to canonical")
23            .to_canonical()
24        } else {
25            array
26                .masked_child()
27                .vortex_expect("masked child of a masked array")
28                .to_canonical()
29        }
30    }
31}
32
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::Nullability;

    use super::*;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Canonicalizing a masked array yields the expected nullability and exactly the masked
    /// array's own dtype. This covers both the generic child path and the constant-child path.
    #[rstest]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::AllValid
        ).unwrap(),
        Nullability::Nullable
    )]
    #[case(
        MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            Validity::from_iter([true, false, true])
        ).unwrap(),
        Nullability::Nullable
    )]
    #[case(
        MaskedArray::try_new(
            ConstantArray::new(7i32, 3).into_array(),
            Validity::from_iter([false, true, false])
        ).unwrap(),
        Nullability::Nullable
    )]
    fn test_canonical_nullability(
        #[case] array: MaskedArray,
        #[case] expected_nullability: Nullability,
    ) {
        let canonical = array.to_canonical();
        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            expected_nullability
        );
        assert_eq!(canonical.as_ref().dtype(), array.dtype());
    }

    /// Positions marked invalid in the masked array's validity become nulls after
    /// canonicalization (generic, non-constant child path).
    #[test]
    fn test_canonical_with_nulls() {
        let array = MaskedArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            Validity::from_iter([true, false, true, false, true]),
        )
        .unwrap();

        let canonical = array.to_canonical();
        let prim = canonical.as_ref().to_primitive();

        // Check that null positions match validity.
        assert_eq!(prim.valid_count(), 3);
        assert!(prim.is_valid(0));
        assert!(!prim.is_valid(1));
        assert!(prim.is_valid(2));
        assert!(!prim.is_valid(3));
        assert!(prim.is_valid(4));
    }

    /// With `Validity::AllValid`, every element stays valid but the dtype is still nullable.
    #[test]
    fn test_canonical_all_valid() {
        let array = MaskedArray::try_new(
            PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let canonical = array.to_canonical();
        assert_eq!(canonical.as_ref().valid_count(), 3);
        assert_eq!(
            canonical.as_ref().dtype().nullability(),
            Nullability::Nullable
        );
    }

    /// Exercises the `ConstantVTable` branch of `canonicalize`. The constant child is
    /// canonicalized before masking, so invalid positions must come back as nulls while valid
    /// positions keep the constant's value.
    #[test]
    fn test_canonical_constant_child_with_nulls() {
        let array = MaskedArray::try_new(
            ConstantArray::new(7i32, 4).into_array(),
            Validity::from_iter([true, false, true, false]),
        )
        .unwrap();

        let canonical = array.to_canonical();
        assert_eq!(canonical.as_ref().dtype(), array.dtype());

        let prim = canonical.as_ref().to_primitive();
        assert_eq!(prim.len(), 4);

        // Null positions must mirror the masked array's validity.
        assert_eq!(prim.valid_count(), 2);
        assert!(prim.is_valid(0));
        assert!(!prim.is_valid(1));
        assert!(prim.is_valid(2));
        assert!(!prim.is_valid(3));

        // Valid slots keep the constant value.
        let values = prim.as_slice::<i32>();
        assert_eq!(values[0], 7);
        assert_eq!(values[2], 7);
    }
}