vortex_bytebool/array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::arrays::BoolArray;
5use vortex_array::stats::{ArrayStats, StatsSetRef};
6use vortex_array::validity::Validity;
7use vortex_array::variants::BoolArrayTrait;
8use vortex_array::vtable::{EncodingVTable, VTableRef};
9use vortex_array::{
10    Array, ArrayCanonicalImpl, ArrayImpl, ArrayStatisticsImpl, ArrayValidityImpl,
11    ArrayVariantsImpl, Canonical, EmptyMetadata, Encoding, EncodingId, try_from_array_ref,
12};
13use vortex_buffer::ByteBuffer;
14use vortex_dtype::DType;
15use vortex_error::{VortexResult, vortex_panic};
16use vortex_mask::Mask;
17
/// A boolean array that stores each value in its own byte, alongside a [`Validity`]
/// describing which elements are null.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    /// Set to `DType::Bool` with the validity's nullability when the array is constructed.
    dtype: DType,
    /// One byte per element; `as_slice` reinterprets these bytes as `bool`s.
    buffer: ByteBuffer,
    /// Per-element validity; when it has a length, that length equals `buffer.len()`.
    validity: Validity,
    /// Statistics storage handed out through `_stats_ref`.
    stats_set: ArrayStats,
}
25
// Generates conversions from a generic array reference into `ByteBoolArray`
// (presumably `TryFrom` impls — confirm against the macro definition in vortex_array).
try_from_array_ref!(ByteBoolArray);
27
/// Encoding marker type for [`ByteBoolArray`].
pub struct ByteBoolEncoding;
impl Encoding for ByteBoolEncoding {
    type Array = ByteBoolArray;
    // This encoding carries no encoding-specific metadata.
    type Metadata = EmptyMetadata;
}
33
impl EncodingVTable for ByteBoolEncoding {
    /// Returns the identifier of this encoding, `"vortex.bytebool"`.
    fn id(&self) -> EncodingId {
        EncodingId::new_ref("vortex.bytebool")
    }
}
39
40impl ByteBoolArray {
41    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
42        let length = buffer.len();
43        if let Some(vlen) = validity.maybe_len() {
44            if length != vlen {
45                vortex_panic!(
46                    "Buffer length ({}) does not match validity length ({})",
47                    length,
48                    vlen
49                );
50            }
51        }
52        Self {
53            dtype: DType::Bool(validity.nullability()),
54            buffer,
55            validity,
56            stats_set: Default::default(),
57        }
58    }
59
60    // TODO(ngates): deprecate construction from vec
61    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
62        let validity = validity.into();
63        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
64        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
65        Self::new(ByteBuffer::from(data), validity)
66    }
67
68    pub fn buffer(&self) -> &ByteBuffer {
69        &self.buffer
70    }
71
72    pub fn validity(&self) -> &Validity {
73        &self.validity
74    }
75
76    pub fn as_slice(&self) -> &[bool] {
77        // Safety: The internal buffer contains byte-sized bools
78        unsafe { std::mem::transmute(self.buffer().as_slice()) }
79    }
80}
81
82impl ArrayImpl for ByteBoolArray {
83    type Encoding = ByteBoolEncoding;
84
85    fn _len(&self) -> usize {
86        self.buffer.len()
87    }
88
89    fn _dtype(&self) -> &DType {
90        &self.dtype
91    }
92
93    fn _vtable(&self) -> VTableRef {
94        VTableRef::new_ref(&ByteBoolEncoding)
95    }
96}
97
98impl ArrayCanonicalImpl for ByteBoolArray {
99    fn _to_canonical(&self) -> VortexResult<Canonical> {
100        let boolean_buffer = BooleanBuffer::from(self.as_slice());
101        let validity = self.validity().clone();
102        Ok(Canonical::Bool(BoolArray::new(boolean_buffer, validity)))
103    }
104}
105
impl ArrayStatisticsImpl for ByteBoolArray {
    /// Exposes this array's `stats_set` as a [`StatsSetRef`] bound to `self`.
    fn _stats_ref(&self) -> StatsSetRef<'_> {
        self.stats_set.to_ref(self)
    }
}
111
112impl ArrayValidityImpl for ByteBoolArray {
113    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
114        self.validity.is_valid(index)
115    }
116
117    fn _all_valid(&self) -> VortexResult<bool> {
118        self.validity.all_valid()
119    }
120
121    fn _all_invalid(&self) -> VortexResult<bool> {
122        self.validity.all_invalid()
123    }
124
125    fn _validity_mask(&self) -> VortexResult<Mask> {
126        self.validity.to_logical(self.len())
127    }
128}
129
impl ArrayVariantsImpl for ByteBoolArray {
    /// The dtype is always `DType::Bool`, so this array always exposes the boolean variant.
    fn _as_bool_typed(&self) -> Option<&dyn BoolArrayTrait> {
        Some(self)
    }
}
135
// No `BoolArrayTrait` methods are overridden here.
impl BoolArrayTrait for ByteBoolArray {}
137
impl From<Vec<bool>> for ByteBoolArray {
    /// Wraps the values with `Validity::AllValid`, so every element is valid. The dtype's
    /// nullability is whatever `Validity::AllValid` reports.
    fn from(value: Vec<bool>) -> Self {
        Self::from_vec(value, Validity::AllValid)
    }
}
143
144impl From<Vec<Option<bool>>> for ByteBoolArray {
145    fn from(value: Vec<Option<bool>>) -> Self {
146        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
147
148        // This doesn't reallocate, and the compiler even vectorizes it
149        let data = value.into_iter().map(Option::unwrap_or_default).collect();
150
151        Self::from_vec(data, validity)
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validity_construction() {
        // Plain booleans: every element is valid.
        let arr = ByteBoolArray::from(vec![true, false]);
        assert_eq!(arr.len(), 2);
        assert!((0..arr.len()).all(|idx| arr.is_valid(idx).unwrap()));

        // Optional booleans: `None` becomes an invalid slot.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.len(), 3);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());

        // All-null input: every element is invalid.
        let arr = ByteBoolArray::from(vec![None::<bool>, None]);
        assert_eq!(arr.len(), 2);
        assert!((0..arr.len()).all(|idx| !arr.is_valid(idx).unwrap()));
    }

    #[test]
    fn test_as_slice_matches_input() {
        let values = vec![true, false, false, true];
        let arr = ByteBoolArray::from(values.clone());
        assert_eq!(arr.as_slice(), values.as_slice());

        // Null slots are stored as `false`.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.as_slice(), &[true, false, false]);
    }
}