vortex_bytebool/
array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::arrays::BoolArray;
5use vortex_array::stats::{ArrayStats, StatsSetRef};
6use vortex_array::validity::Validity;
7use vortex_array::variants::BoolArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11    ArrayVariantsImpl, Canonical, EmptyMetadata, Encoding, try_from_array_ref,
12};
13use vortex_buffer::ByteBuffer;
14use vortex_dtype::DType;
15use vortex_error::{VortexResult, vortex_panic};
16use vortex_mask::Mask;
17
18#[derive(Clone, Debug)]
19pub struct ByteBoolArray {
20    dtype: DType,
21    buffer: ByteBuffer,
22    validity: Validity,
23    stats_set: ArrayStats,
24}
25
26try_from_array_ref!(ByteBoolArray);
27
28pub struct ByteBoolEncoding;
29impl Encoding for ByteBoolEncoding {
30    type Array = ByteBoolArray;
31    type Metadata = EmptyMetadata;
32}
33
34impl ByteBoolArray {
35    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
36        let length = buffer.len();
37        if let Some(vlen) = validity.maybe_len() {
38            if length != vlen {
39                vortex_panic!(
40                    "Buffer length ({}) does not match validity length ({})",
41                    length,
42                    vlen
43                );
44            }
45        }
46        Self {
47            dtype: DType::Bool(validity.nullability()),
48            buffer,
49            validity,
50            stats_set: Default::default(),
51        }
52    }
53
54    // TODO(ngates): deprecate construction from vec
55    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
56        let validity = validity.into();
57        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
58        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
59        Self::new(ByteBuffer::from(data), validity)
60    }
61
62    pub fn buffer(&self) -> &ByteBuffer {
63        &self.buffer
64    }
65
66    pub fn validity(&self) -> &Validity {
67        &self.validity
68    }
69
70    pub fn as_slice(&self) -> &[bool] {
71        // Safety: The internal buffer contains byte-sized bools
72        unsafe { std::mem::transmute(self.buffer().as_slice()) }
73    }
74}
75
76impl ArrayImpl for ByteBoolArray {
77    type Encoding = ByteBoolEncoding;
78
79    fn _len(&self) -> usize {
80        self.buffer.len()
81    }
82
83    fn _dtype(&self) -> &DType {
84        &self.dtype
85    }
86
87    fn _vtable(&self) -> VTableRef {
88        VTableRef::new_ref(&ByteBoolEncoding)
89    }
90
91    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
92        let validity = if self.validity().is_array() {
93            Validity::Array(children[0].clone())
94        } else {
95            self.validity().clone()
96        };
97
98        Ok(Self::new(self.buffer().clone(), validity))
99    }
100}
101
102impl ArrayCanonicalImpl for ByteBoolArray {
103    fn _to_canonical(&self) -> VortexResult<Canonical> {
104        let boolean_buffer = BooleanBuffer::from(self.as_slice());
105        let validity = self.validity().clone();
106        Ok(Canonical::Bool(BoolArray::new(boolean_buffer, validity)))
107    }
108}
109
110impl ArrayStatisticsImpl for ByteBoolArray {
111    fn _stats_ref(&self) -> StatsSetRef<'_> {
112        self.stats_set.to_ref(self)
113    }
114}
115
116impl ArrayValidityImpl for ByteBoolArray {
117    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
118        self.validity.is_valid(index)
119    }
120
121    fn _all_valid(&self) -> VortexResult<bool> {
122        self.validity.all_valid()
123    }
124
125    fn _all_invalid(&self) -> VortexResult<bool> {
126        self.validity.all_invalid()
127    }
128
129    fn _validity_mask(&self) -> VortexResult<Mask> {
130        self.validity.to_mask(self.len())
131    }
132}
133
134impl ArrayVariantsImpl for ByteBoolArray {
135    fn _as_bool_typed(&self) -> Option<&dyn BoolArrayTrait> {
136        Some(self)
137    }
138}
139
140impl BoolArrayTrait for ByteBoolArray {}
141
142impl From<Vec<bool>> for ByteBoolArray {
143    fn from(value: Vec<bool>) -> Self {
144        Self::from_vec(value, Validity::AllValid)
145    }
146}
147
148impl From<Vec<Option<bool>>> for ByteBoolArray {
149    fn from(value: Vec<Option<bool>>) -> Self {
150        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
151
152        // This doesn't reallocate, and the compiler even vectorizes it
153        let data = value.into_iter().map(Option::unwrap_or_default).collect();
154
155        Self::from_vec(data, validity)
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    // TODO: a `check_metadata("bytebool.metadata", ...)` test used to live here and was
    // disabled; restore it if this encoding's metadata needs a snapshot check again.

    #[test]
    fn test_validity_construction() {
        // Plain bools: every slot is valid.
        let arr = ByteBoolArray::from(vec![true, false]);
        assert_eq!(arr.len(), 2);
        assert!((0..arr.len()).all(|idx| arr.is_valid(idx).unwrap()));

        // Mixed options: validity follows `Some`/`None`.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        let observed: Vec<bool> = (0..3).map(|idx| arr.is_valid(idx).unwrap()).collect();
        assert_eq!(observed, [true, false, true]);
        assert_eq!(arr.len(), 3);

        // All `None`: every slot is invalid.
        let nulls: Vec<Option<bool>> = vec![None, None];
        let arr = ByteBoolArray::from(nulls);
        assert_eq!(arr.len(), 2);
        assert!((0..arr.len()).all(|idx| !arr.is_valid(idx).unwrap()));
    }
}
205}