vortex_array/compute/conformance/mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexUnwrap;
5use vortex_mask::Mask;
6
7use crate::Array;
8use crate::arrays::BoolArray;
9use crate::compute::mask;
10
11/// Test mask compute function with various array sizes and patterns.
12/// The mask operation sets elements to null where the mask is true.
13pub fn test_mask_conformance(array: &dyn Array) {
14    let len = array.len();
15
16    if len > 0 {
17        test_heterogenous_mask(array);
18        test_empty_mask(array);
19        test_full_mask(array);
20        test_alternating_mask(array);
21        test_sparse_mask(array);
22        test_single_element_mask(array);
23    }
24
25    if len >= 5 {
26        test_double_mask(array);
27    }
28
29    if len > 0 {
30        test_nullable_mask_input(array);
31    }
32}
33
34/// Tests masking with a heterogeneous pattern
35fn test_heterogenous_mask(array: &dyn Array) {
36    let len = array.len();
37
38    // Create a pattern where roughly half the values are masked
39    let mask_pattern: Vec<bool> = (0..len).map(|i| i % 3 != 1).collect();
40    let mask_array = Mask::from_iter(mask_pattern.clone());
41
42    let masked = mask(array, &mask_array).vortex_unwrap();
43    assert_eq!(masked.len(), array.len());
44
45    // Verify masked elements are null and unmasked elements are preserved
46    for (i, &masked_out) in mask_pattern.iter().enumerate() {
47        if masked_out {
48            assert!(!masked.is_valid(i));
49        } else {
50            assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
51        }
52    }
53}
54
55/// Tests that an empty mask (all false) preserves all elements
56fn test_empty_mask(array: &dyn Array) {
57    let len = array.len();
58    let all_unmasked = vec![false; len];
59    let mask_array = Mask::from_iter(all_unmasked);
60
61    let masked = mask(array, &mask_array).vortex_unwrap();
62    assert_eq!(masked.len(), array.len());
63
64    // All elements should be preserved
65    for i in 0..len {
66        assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
67    }
68}
69
70/// Tests that a full mask (all true) makes all elements null
71fn test_full_mask(array: &dyn Array) {
72    let len = array.len();
73    let all_masked = vec![true; len];
74    let mask_array = Mask::from_iter(all_masked);
75
76    let masked = mask(array, &mask_array).vortex_unwrap();
77    assert_eq!(masked.len(), array.len());
78
79    // All elements should be null
80    for i in 0..len {
81        assert!(!masked.is_valid(i));
82    }
83}
84
85/// Tests alternating mask pattern
86fn test_alternating_mask(array: &dyn Array) {
87    let len = array.len();
88    let pattern: Vec<bool> = (0..len).map(|i| i % 2 == 0).collect();
89    let mask_array = Mask::from_iter(pattern);
90
91    let masked = mask(array, &mask_array).vortex_unwrap();
92    assert_eq!(masked.len(), array.len());
93
94    for i in 0..len {
95        if i % 2 == 0 {
96            assert!(!masked.is_valid(i));
97        } else {
98            assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
99        }
100    }
101}
102
103/// Tests sparse mask (only a few elements masked)
104fn test_sparse_mask(array: &dyn Array) {
105    let len = array.len();
106    if len < 10 {
107        return; // Skip for small arrays
108    }
109
110    // Mask only about 10% of elements
111    let pattern: Vec<bool> = (0..len).map(|i| i % 10 == 0).collect();
112    let mask_array = Mask::from_iter(pattern.clone());
113
114    let masked = mask(array, &mask_array).vortex_unwrap();
115    assert_eq!(masked.len(), array.len());
116
117    // Count how many elements are valid after masking
118    let valid_count = (0..len).filter(|&i| masked.is_valid(i)).count();
119
120    // Count how many elements should be invalid:
121    // - Elements that were masked (pattern[i] == true)
122    // - Elements that were already invalid in the original array
123    let expected_invalid_count = (0..len)
124        .filter(|&i| pattern[i] || !array.is_valid(i))
125        .count();
126
127    assert_eq!(valid_count, len - expected_invalid_count);
128}
129
130/// Tests masking a single element
131fn test_single_element_mask(array: &dyn Array) {
132    let len = array.len();
133
134    // Mask only the first element
135    let mut pattern = vec![false; len];
136    pattern[0] = true;
137    let mask_array = Mask::from_iter(pattern);
138
139    let masked = mask(array, &mask_array).vortex_unwrap();
140    assert!(!masked.is_valid(0));
141
142    for i in 1..len {
143        assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
144    }
145}
146
147/// Tests double masking operations
148fn test_double_mask(array: &dyn Array) {
149    let len = array.len();
150
151    // Create two different mask patterns
152    let mask1_pattern: Vec<bool> = (0..len).map(|i| i % 3 == 0).collect();
153    let mask2_pattern: Vec<bool> = (0..len).map(|i| i % 2 == 0).collect();
154
155    let mask1 = Mask::from_iter(mask1_pattern.clone());
156    let mask2 = Mask::from_iter(mask2_pattern.clone());
157
158    let first_masked = mask(array, &mask1).vortex_unwrap();
159    let double_masked = mask(&first_masked, &mask2).vortex_unwrap();
160
161    // Elements should be null if either mask is true
162    for i in 0..len {
163        if mask1_pattern[i] || mask2_pattern[i] {
164            assert!(!double_masked.is_valid(i));
165        } else {
166            assert_eq!(
167                double_masked.scalar_at(i),
168                array.scalar_at(i).into_nullable()
169            );
170        }
171    }
172}
173
174/// Tests masking with nullable mask (nulls treated as false)
175fn test_nullable_mask_input(array: &dyn Array) {
176    let len = array.len();
177    if len < 3 {
178        return; // Skip for very small arrays
179    }
180
181    // Create a nullable mask
182    let bool_values: Vec<bool> = (0..len).map(|i| i % 2 == 0).collect();
183    let validity_values: Vec<bool> = (0..len).map(|i| i % 3 != 0).collect();
184
185    let bool_array = BoolArray::from_iter(bool_values.clone());
186    let validity = crate::validity::Validity::from_iter(validity_values.clone());
187    let nullable_mask = BoolArray::from_bool_buffer(bool_array.boolean_buffer().clone(), validity);
188
189    let mask_array = nullable_mask.to_mask_fill_null_false();
190    let masked = mask(array, &mask_array).vortex_unwrap();
191
192    // Elements are masked only if the mask is true AND valid
193    for i in 0..len {
194        if bool_values[i] && validity_values[i] {
195            assert!(!masked.is_valid(i));
196        } else {
197            assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
198        }
199    }
200}