vortex_vector/bool/
vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Definition and implementation of [`BoolVector`].
5
6use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::RangeBounds;
9
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_ensure;
14use vortex_mask::Mask;
15
16use crate::VectorOps;
17use crate::bool::BoolScalar;
18use crate::bool::BoolVectorMut;
19
20/// An immutable vector of boolean values.
21///
22/// Internally, this `BoolVector` is a wrapper around a [`BitBuffer`] and a validity mask.
23#[derive(Debug, Clone, Eq)]
24pub struct BoolVector {
25    /// The bits that we use to represent booleans.
26    pub(super) bits: BitBuffer,
27    /// The validity mask (where `true` represents an element is **not** null).
28    pub(super) validity: Mask,
29}
30
31impl PartialEq for BoolVector {
32    fn eq(&self, other: &Self) -> bool {
33        if self.len() != other.len() {
34            return false;
35        }
36        // Validity patterns must match
37        if self.validity != other.validity {
38            return false;
39        }
40        // Use XNOR comparison: bits are equal where !(lhs ^ rhs) is true
41        let lhs_chunks = self.bits.chunks();
42        let rhs_chunks = other.bits.chunks();
43        let validity_bits = self.validity.to_bit_buffer();
44        let validity_chunks = validity_bits.chunks();
45
46        // For equality: check that !(lhs ^ rhs) & validity == validity at each chunk
47        for ((lhs, rhs), valid) in lhs_chunks
48            .iter_padded()
49            .zip(rhs_chunks.iter_padded())
50            .zip(validity_chunks.iter_padded())
51        {
52            let equal_bits = !(lhs ^ rhs); // XNOR: true where bits are equal
53            if (equal_bits & valid) != valid {
54                return false;
55            }
56        }
57        true
58    }
59}
60
61impl BoolVector {
62    /// Creates a new [`BoolVector`] from the given bits and validity mask.
63    ///
64    /// # Panics
65    ///
66    /// Panics if the length of the validity mask does not match the length of the bits.
67    pub fn new(bits: BitBuffer, validity: Mask) -> Self {
68        Self::try_new(bits, validity).vortex_expect("Failed to create `BoolVector`")
69    }
70
71    /// Tries to create a new [`BoolVector`] from the given bits and validity mask.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the length of the validity mask does not match the length of the bits.
76    pub fn try_new(bits: BitBuffer, validity: Mask) -> VortexResult<Self> {
77        vortex_ensure!(
78            validity.len() == bits.len(),
79            "`BoolVector` validity mask must have the same length as bits"
80        );
81
82        Ok(Self { bits, validity })
83    }
84
85    /// Creates a new [`BoolVector`] from the given bits and validity mask without validation.
86    ///
87    /// # Safety
88    ///
89    /// The caller must ensure that the validity mask has the same length as the bits.
90    pub unsafe fn new_unchecked(bits: BitBuffer, validity: Mask) -> Self {
91        if cfg!(debug_assertions) {
92            Self::new(bits, validity)
93        } else {
94            Self { bits, validity }
95        }
96    }
97
98    /// Decomposes the boolean vector into its constituent parts (bit buffer and validity).
99    pub fn into_parts(self) -> (BitBuffer, Mask) {
100        (self.bits, self.validity)
101    }
102
103    /// Returns the bits buffer of the boolean vector.
104    pub fn bits(&self) -> &BitBuffer {
105        &self.bits
106    }
107
108    /// Consumes the boolean vector and returns the bits buffer.
109    pub fn into_bits(self) -> BitBuffer {
110        self.bits
111    }
112
113    /// Gets a nullable element at the given index, panicking on out-of-bounds.
114    ///
115    /// If the element at the given index is null, returns `None`. Otherwise, returns `Some(x)`,
116    /// where `x: bool`.
117    ///
118    /// Note that this `get` method is different from the standard library [`slice::get`], which
119    /// returns `None` if the index is out of bounds. This method will panic if the index is out of
120    /// bounds, and return `None` if the element is null.
121    ///
122    /// # Panics
123    ///
124    /// Panics if the index is out of bounds.
125    pub fn get(&self, index: usize) -> Option<bool> {
126        self.validity.value(index).then(|| self.bits.value(index))
127    }
128}
129
130impl VectorOps for BoolVector {
131    type Mutable = BoolVectorMut;
132    type Scalar = BoolScalar;
133
134    fn len(&self) -> usize {
135        debug_assert!(self.validity.len() == self.bits.len());
136        self.bits.len()
137    }
138
139    fn validity(&self) -> &Mask {
140        &self.validity
141    }
142
143    fn mask_validity(&mut self, mask: &Mask) {
144        self.validity = self.validity.bitand(mask);
145    }
146
147    fn scalar_at(&self, index: usize) -> BoolScalar {
148        assert!(index < self.len());
149
150        let is_valid = self.validity.value(index);
151        let value = is_valid.then(|| self.bits.value(index));
152
153        BoolScalar::new(value)
154    }
155
156    fn slice(&self, range: impl RangeBounds<usize> + Clone + Debug) -> Self {
157        let bits = self.bits.slice(range.clone());
158        let validity = self.validity.slice(range);
159        Self { bits, validity }
160    }
161
162    fn clear(&mut self) {
163        self.bits.clear();
164        self.validity.clear();
165    }
166
167    fn try_into_mut(self) -> Result<BoolVectorMut, Self> {
168        let bits = match self.bits.try_into_mut() {
169            Ok(bits) => bits,
170            Err(bits) => {
171                return Err(Self {
172                    bits,
173                    validity: self.validity,
174                });
175            }
176        };
177
178        match self.validity.try_into_mut() {
179            Ok(validity_mut) => Ok(BoolVectorMut {
180                bits,
181                validity: validity_mut,
182            }),
183            Err(validity) => Err(Self {
184                bits: bits.freeze(),
185                validity,
186            }),
187        }
188    }
189
190    fn into_mut(self) -> BoolVectorMut {
191        BoolVectorMut {
192            bits: self.bits.into_mut(),
193            validity: self.validity.into_mut(),
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use vortex_buffer::BitBuffer;
    use vortex_mask::Mask;

    use super::*;

    #[test]
    fn test_bool_vector_eq_with_validity_127() {
        // 127 is deliberately not a multiple of 64, so a trailing partial chunk is exercised.
        const LEN: usize = 127;

        // Bits alternate true/false; every third position (including index 0) is null.
        let reference_bits: Vec<bool> = (0..LEN).map(|idx| idx % 2 == 0).collect();
        let validity = Mask::from_buffer(BitBuffer::from(
            (0..LEN).map(|idx| idx % 3 != 0).collect::<Vec<bool>>(),
        ));
        let with_validity =
            |bits: &[bool]| BoolVector::new(BitBuffer::from(bits.to_vec()), validity.clone());

        let reference = with_validity(&reference_bits);

        // Identical bits and validity compare equal.
        assert_eq!(reference, with_validity(&reference_bits));

        // Flipping a null position (index 0) must not affect equality.
        let mut edited = reference_bits.clone();
        edited[0] = !edited[0];
        assert_eq!(reference, with_validity(&edited));

        // Additionally flipping a valid position (index 1) must break equality.
        edited[1] = !edited[1];
        assert_ne!(reference, with_validity(&edited));

        // The same bits under a different validity pattern must not compare equal.
        let all_valid = BoolVector::new(BitBuffer::from(reference_bits), Mask::new_true(LEN));
        assert_ne!(reference, all_valid);
    }
}