Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArrayView;
16use vortex_array::ExecutionCtx;
17use vortex_array::ExecutionResult;
18use vortex_array::IntoArray;
19use vortex_array::Precision;
20use vortex_array::TypedArrayRef;
21use vortex_array::arrays::BoolArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::scalar::Scalar;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityVTable;
30use vortex_array::vtable::child_to_validity;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBuffer;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_session::VortexSession;
39use vortex_session::registry::CachedId;
40
41use crate::kernel::PARENT_KERNELS;
42
43/// A [`ByteBool`]-encoded Vortex array.
44pub type ByteBoolArray = Array<ByteBool>;
45
46impl ArrayHash for ByteBoolData {
47    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
48        self.buffer.array_hash(state, precision);
49    }
50}
51
52impl ArrayEq for ByteBoolData {
53    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
54        self.buffer.array_eq(&other.buffer, precision)
55    }
56}
57
58impl VTable for ByteBool {
59    type ArrayData = ByteBoolData;
60
61    type OperationsVTable = Self;
62    type ValidityVTable = Self;
63
64    fn id(&self) -> ArrayId {
65        static ID: CachedId = CachedId::new("vortex.bytebool");
66        *ID
67    }
68
69    fn validate(
70        &self,
71        data: &Self::ArrayData,
72        dtype: &DType,
73        len: usize,
74        slots: &[Option<ArrayRef>],
75    ) -> VortexResult<()> {
76        let validity = child_to_validity(&slots[VALIDITY_SLOT], dtype.nullability());
77        ByteBoolData::validate(data.buffer(), &validity, dtype, len)
78    }
79
80    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
81        1
82    }
83
84    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
85        match idx {
86            0 => array.buffer().clone(),
87            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
88        }
89    }
90
91    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
92        match idx {
93            0 => Some("values".to_string()),
94            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
95        }
96    }
97
98    fn serialize(
99        _array: ArrayView<'_, Self>,
100        _session: &VortexSession,
101    ) -> VortexResult<Option<Vec<u8>>> {
102        Ok(Some(vec![]))
103    }
104
105    fn deserialize(
106        &self,
107        dtype: &DType,
108        len: usize,
109        metadata: &[u8],
110        buffers: &[BufferHandle],
111        children: &dyn ArrayChildren,
112        _session: &VortexSession,
113    ) -> VortexResult<ArrayParts<Self>> {
114        if !metadata.is_empty() {
115            vortex_bail!(
116                "ByteBoolArray expects empty metadata, got {} bytes",
117                metadata.len()
118            );
119        }
120        let validity = if children.is_empty() {
121            Validity::from(dtype.nullability())
122        } else if children.len() == 1 {
123            let validity = children.get(0, &Validity::DTYPE, len)?;
124            Validity::Array(validity)
125        } else {
126            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
127        };
128
129        if buffers.len() != 1 {
130            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
131        }
132        let buffer = buffers[0].clone();
133
134        let data = ByteBoolData::new(buffer, validity.clone());
135        let slots = ByteBoolData::make_slots(&validity, len);
136        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
137    }
138
139    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
140        SLOT_NAMES[idx].to_string()
141    }
142
143    fn reduce_parent(
144        array: ArrayView<'_, Self>,
145        parent: &ArrayRef,
146        child_idx: usize,
147    ) -> VortexResult<Option<ArrayRef>> {
148        crate::rules::RULES.evaluate(array, parent, child_idx)
149    }
150
151    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
152        let boolean_buffer = BitBuffer::from(array.as_slice());
153        let validity = array.validity()?;
154        Ok(ExecutionResult::done(
155            BoolArray::new(boolean_buffer, validity).into_array(),
156        ))
157    }
158
159    fn execute_parent(
160        array: ArrayView<'_, Self>,
161        parent: &ArrayRef,
162        child_idx: usize,
163        ctx: &mut ExecutionCtx,
164    ) -> VortexResult<Option<ArrayRef>> {
165        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
166    }
167}
168
169/// The validity bitmap indicating which elements are non-null.
170pub(super) const VALIDITY_SLOT: usize = 0;
171pub(super) const NUM_SLOTS: usize = 1;
172pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
173
174#[derive(Clone, Debug)]
175pub struct ByteBoolData {
176    buffer: BufferHandle,
177}
178
179impl Display for ByteBoolData {
180    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
181        Ok(())
182    }
183}
184
185pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
186    fn validity(&self) -> Validity {
187        child_to_validity(
188            &self.as_ref().slots()[VALIDITY_SLOT],
189            self.as_ref().dtype().nullability(),
190        )
191    }
192}
193
194impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
195
196#[derive(Clone, Debug)]
197pub struct ByteBool;
198
199impl ByteBool {
200    pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
201        let dtype = DType::Bool(validity.nullability());
202        let slots = ByteBoolData::make_slots(&validity, buffer.len());
203        let data = ByteBoolData::new(buffer, validity);
204        let len = data.len();
205        unsafe {
206            Array::from_parts_unchecked(
207                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
208            )
209        }
210    }
211
212    /// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
213    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
214        let validity = validity.into();
215        let data = ByteBoolData::from_vec(data, validity.clone());
216        let dtype = DType::Bool(validity.nullability());
217        let len = data.len();
218        let slots = ByteBoolData::make_slots(&validity, len);
219        unsafe {
220            Array::from_parts_unchecked(
221                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
222            )
223        }
224    }
225
226    /// Construct a [`ByteBoolArray`] from optional bools.
227    pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
228        let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
229        let data = ByteBoolData::from(data);
230        let dtype = DType::Bool(validity.nullability());
231        let len = data.len();
232        let slots = ByteBoolData::make_slots(&validity, len);
233        unsafe {
234            Array::from_parts_unchecked(
235                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
236            )
237        }
238    }
239}
240
241impl ByteBoolData {
242    pub fn validate(
243        buffer: &BufferHandle,
244        validity: &Validity,
245        dtype: &DType,
246        len: usize,
247    ) -> VortexResult<()> {
248        let expected_dtype = DType::Bool(validity.nullability());
249        vortex_ensure!(
250            dtype == &expected_dtype,
251            "expected dtype {expected_dtype}, got {dtype}"
252        );
253        vortex_ensure!(
254            buffer.len() == len,
255            "expected len {len}, got {}",
256            buffer.len()
257        );
258        if let Some(vlen) = validity.maybe_len() {
259            vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
260        }
261        Ok(())
262    }
263
264    fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
265        vec![validity_to_child(validity, len)]
266    }
267
268    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
269        let length = buffer.len();
270        if let Some(vlen) = validity.maybe_len()
271            && length != vlen
272        {
273            vortex_panic!(
274                "Buffer length ({}) does not match validity length ({})",
275                length,
276                vlen
277            );
278        }
279        Self { buffer }
280    }
281
282    /// Returns the number of elements in the array.
283    pub fn len(&self) -> usize {
284        self.buffer.len()
285    }
286
287    /// Returns `true` if the array contains no elements.
288    pub fn is_empty(&self) -> bool {
289        self.buffer.len() == 0
290    }
291
292    // TODO(ngates): deprecate construction from vec
293    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
294        let validity = validity.into();
295        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
296        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
297        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
298    }
299
300    pub fn buffer(&self) -> &BufferHandle {
301        &self.buffer
302    }
303
304    pub fn as_slice(&self) -> &[bool] {
305        // Safety: The internal buffer contains byte-sized bools
306        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
307    }
308}
309
310impl ValidityVTable<ByteBool> for ByteBool {
311    fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
312        Ok(ByteBoolArrayExt::validity(&array))
313    }
314}
315
316impl OperationsVTable<ByteBool> for ByteBool {
317    fn scalar_at(
318        array: ArrayView<'_, ByteBool>,
319        index: usize,
320        _ctx: &mut ExecutionCtx,
321    ) -> VortexResult<Scalar> {
322        Ok(Scalar::bool(
323            array.buffer.as_host()[index] == 1,
324            array.dtype().nullability(),
325        ))
326    }
327}
328
329impl From<Vec<bool>> for ByteBoolData {
330    fn from(value: Vec<bool>) -> Self {
331        Self::from_vec(value, Validity::AllValid)
332    }
333}
334
335impl From<Vec<Option<bool>>> for ByteBoolData {
336    fn from(value: Vec<Option<bool>>) -> Self {
337        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
338
339        // This doesn't reallocate, and the compiler even vectorizes it
340        let data = value.into_iter().map(Option::unwrap_or_default).collect();
341
342        Self::from_vec(data, validity)
343    }
344}
345
346#[cfg(test)]
347mod tests {
348    use vortex_array::ArrayContext;
349    use vortex_array::IntoArray;
350    use vortex_array::LEGACY_SESSION;
351    use vortex_array::VortexSessionExecute;
352    use vortex_array::assert_arrays_eq;
353    use vortex_array::serde::SerializeOptions;
354    use vortex_array::serde::SerializedArray;
355    use vortex_array::session::ArraySession;
356    use vortex_array::session::ArraySessionExt;
357    use vortex_buffer::ByteBufferMut;
358    use vortex_session::VortexSession;
359    use vortex_session::registry::ReadContext;
360
361    use super::*;
362
363    #[test]
364    fn test_validity_construction() {
365        let v = vec![true, false];
366        let v_len = v.len();
367
368        let arr = ByteBool::from_vec(v, Validity::AllValid);
369        assert_eq!(v_len, arr.len());
370
371        let mut ctx = LEGACY_SESSION.create_execution_ctx();
372        for idx in 0..arr.len() {
373            assert!(arr.is_valid(idx, &mut ctx).unwrap());
374        }
375
376        let v = vec![Some(true), None, Some(false)];
377        let arr = ByteBool::from_option_vec(v);
378        assert!(arr.is_valid(0, &mut ctx).unwrap());
379        assert!(!arr.is_valid(1, &mut ctx).unwrap());
380        assert!(arr.is_valid(2, &mut ctx).unwrap());
381        assert_eq!(arr.len(), 3);
382
383        let v: Vec<Option<bool>> = vec![None, None];
384        let v_len = v.len();
385
386        let arr = ByteBool::from_option_vec(v);
387        assert_eq!(v_len, arr.len());
388
389        for idx in 0..arr.len() {
390            assert!(!arr.is_valid(idx, &mut ctx).unwrap());
391        }
392        assert_eq!(arr.len(), 2);
393    }
394
395    #[test]
396    fn test_nullable_bytebool_serde_roundtrip() {
397        let array = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
398        let dtype = array.dtype().clone();
399        let len = array.len();
400        let session = VortexSession::empty().with::<ArraySession>();
401        session.arrays().register(ByteBool);
402
403        let ctx = ArrayContext::empty();
404        let serialized = array
405            .clone()
406            .into_array()
407            .serialize(&ctx, &session, &SerializeOptions::default())
408            .unwrap();
409
410        let mut concat = ByteBufferMut::empty();
411        for buf in serialized {
412            concat.extend_from_slice(buf.as_ref());
413        }
414
415        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
416        let decoded = parts
417            .decode(&dtype, len, &ReadContext::new(ctx.to_ids()), &session)
418            .unwrap();
419
420        assert_arrays_eq!(decoded, array);
421    }
422}