vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Range;
7
8use vortex_array::ArrayBufferVisitor;
9use vortex_array::ArrayChildVisitor;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::EmptyMetadata;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::arrays::BoolArray;
18use vortex_array::buffer::BufferHandle;
19use vortex_array::serde::ArrayChildren;
20use vortex_array::stats::ArrayStats;
21use vortex_array::stats::StatsSetRef;
22use vortex_array::validity::Validity;
23use vortex_array::vtable;
24use vortex_array::vtable::ArrayId;
25use vortex_array::vtable::ArrayVTable;
26use vortex_array::vtable::ArrayVTableExt;
27use vortex_array::vtable::BaseArrayVTable;
28use vortex_array::vtable::CanonicalVTable;
29use vortex_array::vtable::NotSupported;
30use vortex_array::vtable::OperationsVTable;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityHelper;
33use vortex_array::vtable::ValidityVTableFromValidityHelper;
34use vortex_array::vtable::VisitorVTable;
35use vortex_buffer::BitBuffer;
36use vortex_buffer::ByteBuffer;
37use vortex_dtype::DType;
38use vortex_error::VortexExpect;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_panic;
43use vortex_scalar::Scalar;
44
// Invoke the `vtable!` macro (imported from `vortex_array`) for the `ByteBool` encoding.
// NOTE(review): presumably this generates the glue between `ByteBoolVTable` and
// `ByteBoolArray` (e.g. the `as_ref` / `into_array` conversions used further down) —
// confirm against the macro definition.
vtable!(ByteBool);
46
47impl VTable for ByteBoolVTable {
48    type Array = ByteBoolArray;
49
50    type Metadata = EmptyMetadata;
51
52    type ArrayVTable = Self;
53    type CanonicalVTable = Self;
54    type OperationsVTable = Self;
55    type ValidityVTable = ValidityVTableFromValidityHelper;
56    type VisitorVTable = Self;
57    type ComputeVTable = NotSupported;
58    type EncodeVTable = NotSupported;
59
60    fn id(&self) -> ArrayId {
61        ArrayId::new_ref("vortex.bytebool")
62    }
63
64    fn encoding(_array: &Self::Array) -> ArrayVTable {
65        ByteBoolVTable.as_vtable()
66    }
67
68    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
69        Ok(EmptyMetadata)
70    }
71
72    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
73        Ok(Some(vec![]))
74    }
75
76    fn deserialize(_buffer: &[u8]) -> VortexResult<Self::Metadata> {
77        Ok(EmptyMetadata)
78    }
79
80    fn build(
81        &self,
82        dtype: &DType,
83        len: usize,
84        _metadata: &Self::Metadata,
85        buffers: &[BufferHandle],
86        children: &dyn ArrayChildren,
87    ) -> VortexResult<ByteBoolArray> {
88        let validity = if children.is_empty() {
89            Validity::from(dtype.nullability())
90        } else if children.len() == 1 {
91            let validity = children.get(0, &Validity::DTYPE, len)?;
92            Validity::Array(validity)
93        } else {
94            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
95        };
96
97        if buffers.len() != 1 {
98            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
99        }
100        let buffer = buffers[0].clone().try_to_bytes()?;
101
102        Ok(ByteBoolArray::new(buffer, validity))
103    }
104
105    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
106        vortex_ensure!(
107            children.len() <= 1,
108            "ByteBoolArray expects at most 1 child (validity), got {}",
109            children.len()
110        );
111
112        array.validity = if children.is_empty() {
113            Validity::from(array.dtype.nullability())
114        } else {
115            Validity::Array(children.into_iter().next().vortex_expect("checked"))
116        };
117
118        Ok(())
119    }
120}
121
/// Boolean array that stores one byte per element.
///
/// The array length is the byte length of `buffer`, and the dtype is always
/// `DType::Bool` with the nullability reported by `validity` at construction time.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    // Logical type; set to `DType::Bool(validity.nullability())` by `new`.
    dtype: DType,
    // One byte per element; reinterpreted as `bool` by `as_slice`.
    buffer: ByteBuffer,
    // Validity (null-ness) information for the elements.
    validity: Validity,
    // Statistics holder exposed through the vtable's `stats` accessor.
    stats_set: ArrayStats,
}
129
/// Unit struct on which the vtable traits for [`ByteBoolArray`] are implemented.
///
/// Its encoding id is `"vortex.bytebool"`.
#[derive(Debug)]
pub struct ByteBoolVTable;
132
133impl ByteBoolArray {
134    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
135        let length = buffer.len();
136        if let Some(vlen) = validity.maybe_len()
137            && length != vlen
138        {
139            vortex_panic!(
140                "Buffer length ({}) does not match validity length ({})",
141                length,
142                vlen
143            );
144        }
145        Self {
146            dtype: DType::Bool(validity.nullability()),
147            buffer,
148            validity,
149            stats_set: Default::default(),
150        }
151    }
152
153    // TODO(ngates): deprecate construction from vec
154    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
155        let validity = validity.into();
156        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
157        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
158        Self::new(ByteBuffer::from(data), validity)
159    }
160
161    pub fn buffer(&self) -> &ByteBuffer {
162        &self.buffer
163    }
164
165    pub fn as_slice(&self) -> &[bool] {
166        // Safety: The internal buffer contains byte-sized bools
167        unsafe { std::mem::transmute(self.buffer().as_slice()) }
168    }
169}
170
/// Exposes the stored validity; the vtable's `ValidityVTable` is
/// `ValidityVTableFromValidityHelper`, which is wired to this impl.
impl ValidityHelper for ByteBoolArray {
    /// Returns a reference to the array's validity.
    fn validity(&self) -> &Validity {
        &self.validity
    }
}
176
177impl BaseArrayVTable<ByteBoolVTable> for ByteBoolVTable {
178    fn len(array: &ByteBoolArray) -> usize {
179        array.buffer.len()
180    }
181
182    fn dtype(array: &ByteBoolArray) -> &DType {
183        &array.dtype
184    }
185
186    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
187        array.stats_set.to_ref(array.as_ref())
188    }
189
190    fn array_hash<H: std::hash::Hasher>(
191        array: &ByteBoolArray,
192        state: &mut H,
193        precision: Precision,
194    ) {
195        array.dtype.hash(state);
196        array.buffer.array_hash(state, precision);
197        array.validity.array_hash(state, precision);
198    }
199
200    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
201        array.dtype == other.dtype
202            && array.buffer.array_eq(&other.buffer, precision)
203            && array.validity.array_eq(&other.validity, precision)
204    }
205}
206
207impl CanonicalVTable<ByteBoolVTable> for ByteBoolVTable {
208    fn canonicalize(array: &ByteBoolArray) -> Canonical {
209        let boolean_buffer = BitBuffer::from(array.as_slice());
210        let validity = array.validity().clone();
211        Canonical::Bool(BoolArray::from_bit_buffer(boolean_buffer, validity))
212    }
213}
214
215impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
216    fn slice(array: &ByteBoolArray, range: Range<usize>) -> ArrayRef {
217        ByteBoolArray::new(
218            array.buffer().slice(range.clone()),
219            array.validity().slice(range),
220        )
221        .into_array()
222    }
223
224    fn scalar_at(array: &ByteBoolArray, index: usize) -> Scalar {
225        Scalar::bool(array.buffer()[index] == 1, array.dtype().nullability())
226    }
227}
228
229impl VisitorVTable<ByteBoolVTable> for ByteBoolVTable {
230    fn visit_buffers(array: &ByteBoolArray, visitor: &mut dyn ArrayBufferVisitor) {
231        visitor.visit_buffer(array.buffer());
232    }
233
234    fn visit_children(array: &ByteBoolArray, visitor: &mut dyn ArrayChildVisitor) {
235        visitor.visit_validity(array.validity(), array.len());
236    }
237}
238
/// Builds an array from plain booleans, using `Validity::AllValid`.
impl From<Vec<bool>> for ByteBoolArray {
    /// Delegates to `from_vec` with every element marked valid.
    fn from(value: Vec<bool>) -> Self {
        Self::from_vec(value, Validity::AllValid)
    }
}
244
245impl From<Vec<Option<bool>>> for ByteBoolArray {
246    fn from(value: Vec<Option<bool>>) -> Self {
247        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
248
249        // This doesn't reallocate, and the compiler even vectorizes it
250        let data = value.into_iter().map(Option::unwrap_or_default).collect();
251
252        Self::from_vec(data, validity)
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks length and per-element validity for the three `From` construction
    /// paths: all valid, mixed, and all null.
    #[test]
    fn test_validity_construction() {
        // Plain booleans: every element is valid.
        let v = vec![true, false];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx), "index {idx} should be valid");
        }

        // Mixed options: only the `None` slot is invalid.
        let v = vec![Some(true), None, Some(false)];
        let arr = ByteBoolArray::from(v);
        assert!(arr.is_valid(0));
        assert!(!arr.is_valid(1));
        assert!(arr.is_valid(2));
        assert_eq!(arr.len(), 3);

        // All `None`: every element is invalid.
        let v: Vec<Option<bool>> = vec![None, None];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx), "index {idx} should be invalid");
        }
        assert_eq!(arr.len(), 2);
    }
}