Skip to main content

vortex_mask/
eq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::mem;
5
6use crate::Mask;
7
8impl PartialEq for Mask {
9    #[inline]
10    fn eq(&self, other: &Self) -> bool {
11        if self.len() != other.len() {
12            return false;
13        }
14        if mem::discriminant(self) == mem::discriminant(other) && !matches!(self, Mask::Values(_)) {
15            return true;
16        }
17        if self.true_count() != other.true_count() {
18            return false;
19        }
20
21        // TODO(ngates): we could compare by indices if density is low enough
22        self.bit_buffer() == other.bit_buffer()
23    }
24}
25
26impl Eq for Mask {}
27
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;

    use crate::Mask;

    #[test]
    fn filter_mask_eq() {
        assert_eq!(Mask::new_true(5), Mask::from_buffer(BitBuffer::new_set(5)));
        assert_eq!(
            Mask::new_false(5),
            Mask::from_buffer(BitBuffer::new_unset(5))
        );
        assert_eq!(
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_slices(5, vec![(0, 1), (2, 4)])
        );
        assert_eq!(
            Mask::from_indices(5, vec![0, 2, 3]),
            Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false]))
        );
    }

    #[test]
    fn test_mask_eq_different_lengths() {
        let mask1 = Mask::new_true(5);
        let mask2 = Mask::new_true(3);
        assert_ne!(mask1, mask2);
    }

    #[test]
    fn test_mask_eq_different_true_counts() {
        let mask1 = Mask::from_buffer(BitBuffer::from_iter([true, true, false]));
        let mask2 = Mask::from_buffer(BitBuffer::from_iter([true, false, false]));
        assert_ne!(mask1, mask2);
    }

    #[test]
    fn test_mask_eq_same_count_different_positions() {
        let mask1 = Mask::from_buffer(BitBuffer::from_iter([true, false, false]));
        let mask2 = Mask::from_buffer(BitBuffer::from_iter([false, true, false]));
        assert_ne!(mask1, mask2);
    }

    #[test]
    fn test_mask_eq_all_variants() {
        // AllTrue == AllTrue
        let all_true1 = Mask::new_true(5);
        let all_true2 = Mask::new_true(5);
        assert_eq!(all_true1, all_true2);

        // AllFalse == AllFalse
        let all_false1 = Mask::new_false(5);
        let all_false2 = Mask::new_false(5);
        assert_eq!(all_false1, all_false2);

        // AllTrue != AllFalse
        assert_ne!(all_true1, all_false1);

        // Values == Values
        let values1 = Mask::from_buffer(BitBuffer::from_iter([true, false, true]));
        let values2 = Mask::from_buffer(BitBuffer::from_iter([true, false, true]));
        assert_eq!(values1, values2);

        // A mask built from an all-set buffer equals AllTrue of the same length,
        // whichever variant `from_buffer` chooses to represent it.
        let all_true_values = Mask::from_buffer(BitBuffer::new_set(5));
        assert_eq!(all_true1, all_true_values);

        // Likewise, a mask built from an all-unset buffer equals AllFalse of the same length.
        let all_false_values = Mask::from_buffer(BitBuffer::new_unset(5));
        assert_eq!(all_false1, all_false_values);
    }

    #[test]
    fn test_mask_eq_reflexive() {
        // A mask equals itself.
        let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, false, true]));
        assert_eq!(mask, mask);
    }

    #[test]
    fn test_mask_eq_symmetric() {
        // If a == b then b == a.
        let mask1 = Mask::from_indices(5, vec![0, 2, 4]);
        let mask2 = Mask::from_slices(5, vec![(0, 1), (2, 3), (4, 5)]);
        assert_eq!(mask1, mask2);
        assert_eq!(mask2, mask1);
    }

    #[test]
    fn test_mask_eq_transitive() {
        // If a == b and b == c then a == c.
        let mask1 = Mask::from_indices(5, vec![1, 3]);
        let mask2 = Mask::from_slices(5, vec![(1, 2), (3, 4)]);
        let mask3 = Mask::from_buffer(BitBuffer::from_iter([false, true, false, true, false]));

        assert_eq!(mask1, mask2);
        assert_eq!(mask2, mask3);
        assert_eq!(mask1, mask3);
    }

    #[test]
    fn test_mask_eq_empty() {
        // `from_buffer` turns any empty buffer into AllFalse(0), while `new_true(0)`
        // keeps the AllTrue(0) variant.
        let empty1 = Mask::new_true(0);
        let empty2 = Mask::new_false(0);
        let empty3 = Mask::from_buffer(BitBuffer::new_set(0));
        let empty4 = Mask::from_buffer(BitBuffer::new_unset(0));

        // Both buffer-built empty masks are AllFalse(0), even the one from a "set" buffer.
        assert!(matches!(empty3, Mask::AllFalse(0)));
        assert!(matches!(empty4, Mask::AllFalse(0)));

        // new_true(0) is AllTrue(0), new_false(0) is AllFalse(0).
        assert!(matches!(empty1, Mask::AllTrue(0)));
        assert!(matches!(empty2, Mask::AllFalse(0)));

        // Empty masks sharing the AllFalse(0) representation compare equal.
        assert_eq!(empty2, empty3);
        assert_eq!(empty3, empty4);
        assert_eq!(empty2, empty4);
    }

    #[test]
    fn test_mask_eq_different_representations() {
        // Masks with the same logical values but different construction paths are equal.
        let indices = vec![0, 1, 2, 5, 6, 9];
        let slices = vec![(0, 3), (5, 7), (9, 10)];
        let buffer = BitBuffer::from_iter([
            true, true, true, false, false, true, true, false, false, true,
        ]);

        let mask1 = Mask::from_indices(10, indices);
        let mask2 = Mask::from_slices(10, slices);
        let mask3 = Mask::from_buffer(buffer);

        assert_eq!(mask1, mask2);
        assert_eq!(mask2, mask3);
        assert_eq!(mask1, mask3);
    }
}