Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArrayView;
16use vortex_array::ExecutionCtx;
17use vortex_array::ExecutionResult;
18use vortex_array::IntoArray;
19use vortex_array::Precision;
20use vortex_array::TypedArrayRef;
21use vortex_array::arrays::BoolArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::scalar::Scalar;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityVTable;
30use vortex_array::vtable::child_to_validity;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBuffer;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_mask::Mask;
39use vortex_session::VortexSession;
40
41use crate::kernel::PARENT_KERNELS;
42
43/// A [`ByteBool`]-encoded Vortex array.
44pub type ByteBoolArray = Array<ByteBool>;
45
46impl ArrayHash for ByteBoolData {
47    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
48        self.buffer.array_hash(state, precision);
49    }
50}
51
52impl ArrayEq for ByteBoolData {
53    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
54        self.buffer.array_eq(&other.buffer, precision)
55    }
56}
57
58impl VTable for ByteBool {
59    type ArrayData = ByteBoolData;
60
61    type OperationsVTable = Self;
62    type ValidityVTable = Self;
63
64    fn id(&self) -> ArrayId {
65        Self::ID
66    }
67
68    fn validate(
69        &self,
70        data: &Self::ArrayData,
71        dtype: &DType,
72        len: usize,
73        slots: &[Option<ArrayRef>],
74    ) -> VortexResult<()> {
75        let validity = child_to_validity(&slots[VALIDITY_SLOT], dtype.nullability());
76        ByteBoolData::validate(data.buffer(), &validity, dtype, len)
77    }
78
79    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
80        1
81    }
82
83    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
84        match idx {
85            0 => array.buffer().clone(),
86            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
87        }
88    }
89
90    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
91        match idx {
92            0 => Some("values".to_string()),
93            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
94        }
95    }
96
97    fn serialize(
98        _array: ArrayView<'_, Self>,
99        _session: &VortexSession,
100    ) -> VortexResult<Option<Vec<u8>>> {
101        Ok(Some(vec![]))
102    }
103
104    fn deserialize(
105        &self,
106        dtype: &DType,
107        len: usize,
108        metadata: &[u8],
109        buffers: &[BufferHandle],
110        children: &dyn ArrayChildren,
111        _session: &VortexSession,
112    ) -> VortexResult<ArrayParts<Self>> {
113        if !metadata.is_empty() {
114            vortex_bail!(
115                "ByteBoolArray expects empty metadata, got {} bytes",
116                metadata.len()
117            );
118        }
119        let validity = if children.is_empty() {
120            Validity::from(dtype.nullability())
121        } else if children.len() == 1 {
122            let validity = children.get(0, &Validity::DTYPE, len)?;
123            Validity::Array(validity)
124        } else {
125            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
126        };
127
128        if buffers.len() != 1 {
129            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
130        }
131        let buffer = buffers[0].clone();
132
133        let data = ByteBoolData::new(buffer, validity.clone());
134        let slots = ByteBoolData::make_slots(&validity, len);
135        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
136    }
137
138    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
139        SLOT_NAMES[idx].to_string()
140    }
141
142    fn reduce_parent(
143        array: ArrayView<'_, Self>,
144        parent: &ArrayRef,
145        child_idx: usize,
146    ) -> VortexResult<Option<ArrayRef>> {
147        crate::rules::RULES.evaluate(array, parent, child_idx)
148    }
149
150    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
151        let boolean_buffer = BitBuffer::from(array.as_slice());
152        let validity = array.validity()?;
153        Ok(ExecutionResult::done(
154            BoolArray::new(boolean_buffer, validity).into_array(),
155        ))
156    }
157
158    fn execute_parent(
159        array: ArrayView<'_, Self>,
160        parent: &ArrayRef,
161        child_idx: usize,
162        ctx: &mut ExecutionCtx,
163    ) -> VortexResult<Option<ArrayRef>> {
164        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
165    }
166}
167
168/// The validity bitmap indicating which elements are non-null.
169pub(super) const VALIDITY_SLOT: usize = 0;
170pub(super) const NUM_SLOTS: usize = 1;
171pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
172
173#[derive(Clone, Debug)]
174pub struct ByteBoolData {
175    buffer: BufferHandle,
176}
177
178impl Display for ByteBoolData {
179    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
180        Ok(())
181    }
182}
183
184pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
185    fn validity(&self) -> Validity {
186        child_to_validity(
187            &self.as_ref().slots()[VALIDITY_SLOT],
188            self.as_ref().dtype().nullability(),
189        )
190    }
191
192    fn validity_mask(&self) -> Mask {
193        self.validity().to_mask(self.as_ref().len())
194    }
195}
196
197impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
198
199#[derive(Clone, Debug)]
200pub struct ByteBool;
201
202impl ByteBool {
203    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
204
205    pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
206        let dtype = DType::Bool(validity.nullability());
207        let slots = ByteBoolData::make_slots(&validity, buffer.len());
208        let data = ByteBoolData::new(buffer, validity);
209        let len = data.len();
210        unsafe {
211            Array::from_parts_unchecked(
212                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
213            )
214        }
215    }
216
217    /// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
218    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
219        let validity = validity.into();
220        let data = ByteBoolData::from_vec(data, validity.clone());
221        let dtype = DType::Bool(validity.nullability());
222        let len = data.len();
223        let slots = ByteBoolData::make_slots(&validity, len);
224        unsafe {
225            Array::from_parts_unchecked(
226                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
227            )
228        }
229    }
230
231    /// Construct a [`ByteBoolArray`] from optional bools.
232    pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
233        let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
234        let data = ByteBoolData::from(data);
235        let dtype = DType::Bool(validity.nullability());
236        let len = data.len();
237        let slots = ByteBoolData::make_slots(&validity, len);
238        unsafe {
239            Array::from_parts_unchecked(
240                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
241            )
242        }
243    }
244}
245
246impl ByteBoolData {
247    pub fn validate(
248        buffer: &BufferHandle,
249        validity: &Validity,
250        dtype: &DType,
251        len: usize,
252    ) -> VortexResult<()> {
253        let expected_dtype = DType::Bool(validity.nullability());
254        vortex_ensure!(
255            dtype == &expected_dtype,
256            "expected dtype {expected_dtype}, got {dtype}"
257        );
258        vortex_ensure!(
259            buffer.len() == len,
260            "expected len {len}, got {}",
261            buffer.len()
262        );
263        if let Some(vlen) = validity.maybe_len() {
264            vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
265        }
266        Ok(())
267    }
268
269    fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
270        vec![validity_to_child(validity, len)]
271    }
272
273    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
274        let length = buffer.len();
275        if let Some(vlen) = validity.maybe_len()
276            && length != vlen
277        {
278            vortex_panic!(
279                "Buffer length ({}) does not match validity length ({})",
280                length,
281                vlen
282            );
283        }
284        Self { buffer }
285    }
286
287    /// Returns the number of elements in the array.
288    pub fn len(&self) -> usize {
289        self.buffer.len()
290    }
291
292    /// Returns `true` if the array contains no elements.
293    pub fn is_empty(&self) -> bool {
294        self.buffer.len() == 0
295    }
296
297    // TODO(ngates): deprecate construction from vec
298    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
299        let validity = validity.into();
300        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
301        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
302        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
303    }
304
305    pub fn buffer(&self) -> &BufferHandle {
306        &self.buffer
307    }
308
309    pub fn as_slice(&self) -> &[bool] {
310        // Safety: The internal buffer contains byte-sized bools
311        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
312    }
313}
314
315impl ValidityVTable<ByteBool> for ByteBool {
316    fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
317        Ok(ByteBoolArrayExt::validity(&array))
318    }
319}
320
321impl OperationsVTable<ByteBool> for ByteBool {
322    fn scalar_at(
323        array: ArrayView<'_, ByteBool>,
324        index: usize,
325        _ctx: &mut ExecutionCtx,
326    ) -> VortexResult<Scalar> {
327        Ok(Scalar::bool(
328            array.buffer.as_host()[index] == 1,
329            array.dtype().nullability(),
330        ))
331    }
332}
333
334impl From<Vec<bool>> for ByteBoolData {
335    fn from(value: Vec<bool>) -> Self {
336        Self::from_vec(value, Validity::AllValid)
337    }
338}
339
340impl From<Vec<Option<bool>>> for ByteBoolData {
341    fn from(value: Vec<Option<bool>>) -> Self {
342        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
343
344        // This doesn't reallocate, and the compiler even vectorizes it
345        let data = value.into_iter().map(Option::unwrap_or_default).collect();
346
347        Self::from_vec(data, validity)
348    }
349}
350
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    #[test]
    fn test_validity_construction() {
        let v = vec![true, false];
        let v_len = v.len();

        let arr = ByteBool::from_vec(v, Validity::AllValid);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx).unwrap());
        }

        let v = vec![Some(true), None, Some(false)];
        let arr = ByteBool::from_option_vec(v);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());
        assert_eq!(arr.len(), 3);

        let v: Vec<Option<bool>> = vec![None, None];
        let v_len = v.len();

        let arr = ByteBool::from_option_vec(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx).unwrap());
        }
        assert_eq!(arr.len(), 2);
    }

    /// `from_vec` must store each bool as a 0/1 byte, and `as_slice` must read them back.
    #[test]
    fn test_from_vec_as_slice_roundtrip() {
        let values = vec![true, false, false, true, true];
        let data = ByteBoolData::from_vec(values.clone(), Validity::AllValid);
        assert_eq!(data.len(), values.len());
        assert_eq!(data.as_slice(), values.as_slice());
        assert_eq!(data.buffer().as_host().as_slice(), &[1u8, 0, 0, 1, 1]);

        let empty = ByteBoolData::from_vec(Vec::new(), Validity::AllValid);
        assert!(empty.is_empty());
        assert!(empty.as_slice().is_empty());
    }

    /// Null entries built from `None` must store `false` in the values buffer.
    #[test]
    fn test_from_option_vec_nulls_store_false() {
        let arr = ByteBool::from_option_vec(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.as_slice(), &[true, false, false]);
    }

    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let array = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let dtype = array.dtype().clone();
        let len = array.len();
        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let ctx = ArrayContext::empty();
        let serialized = array
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        let mut concat = ByteBufferMut::empty();
        for buf in serialized {
            concat.extend_from_slice(buf.as_ref());
        }

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        let decoded = parts
            .decode(&dtype, len, &ReadContext::new(ctx.to_ids()), &session)
            .unwrap();

        assert_arrays_eq!(decoded, array);
    }
}