Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArrayView;
16use vortex_array::ExecutionCtx;
17use vortex_array::ExecutionResult;
18use vortex_array::IntoArray;
19use vortex_array::Precision;
20use vortex_array::TypedArrayRef;
21use vortex_array::arrays::BoolArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::scalar::Scalar;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityVTable;
30use vortex_array::vtable::child_to_validity;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBufferMut;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_session::VortexSession;
39use vortex_session::registry::CachedId;
40
41use crate::kernel::PARENT_KERNELS;
42
43/// A [`ByteBool`]-encoded Vortex array.
44pub type ByteBoolArray = Array<ByteBool>;
45
46impl ArrayHash for ByteBoolData {
47    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
48        self.buffer.array_hash(state, precision);
49    }
50}
51
52impl ArrayEq for ByteBoolData {
53    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
54        self.buffer.array_eq(&other.buffer, precision)
55    }
56}
57
58impl VTable for ByteBool {
59    type ArrayData = ByteBoolData;
60
61    type OperationsVTable = Self;
62    type ValidityVTable = Self;
63
64    fn id(&self) -> ArrayId {
65        static ID: CachedId = CachedId::new("vortex.bytebool");
66        *ID
67    }
68
69    fn validate(
70        &self,
71        data: &Self::ArrayData,
72        dtype: &DType,
73        len: usize,
74        slots: &[Option<ArrayRef>],
75    ) -> VortexResult<()> {
76        let validity = child_to_validity(slots[VALIDITY_SLOT].as_ref(), dtype.nullability());
77        ByteBoolData::validate(data.buffer(), &validity, dtype, len)
78    }
79
80    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
81        1
82    }
83
84    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
85        match idx {
86            0 => array.buffer().clone(),
87            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
88        }
89    }
90
91    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
92        match idx {
93            0 => Some("values".to_string()),
94            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
95        }
96    }
97
98    fn serialize(
99        _array: ArrayView<'_, Self>,
100        _session: &VortexSession,
101    ) -> VortexResult<Option<Vec<u8>>> {
102        Ok(Some(vec![]))
103    }
104
105    fn deserialize(
106        &self,
107        dtype: &DType,
108        len: usize,
109        metadata: &[u8],
110        buffers: &[BufferHandle],
111        children: &dyn ArrayChildren,
112        _session: &VortexSession,
113    ) -> VortexResult<ArrayParts<Self>> {
114        if !metadata.is_empty() {
115            vortex_bail!(
116                "ByteBoolArray expects empty metadata, got {} bytes",
117                metadata.len()
118            );
119        }
120        let validity = if children.is_empty() {
121            Validity::from(dtype.nullability())
122        } else if children.len() == 1 {
123            let validity = children.get(0, &Validity::DTYPE, len)?;
124            Validity::Array(validity)
125        } else {
126            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
127        };
128
129        if buffers.len() != 1 {
130            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
131        }
132        let buffer = buffers[0].clone();
133
134        let data = ByteBoolData::new(buffer);
135        let slots = ByteBoolData::make_slots(&validity, len);
136        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
137    }
138
139    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
140        SLOT_NAMES[idx].to_string()
141    }
142
143    fn reduce_parent(
144        array: ArrayView<'_, Self>,
145        parent: &ArrayRef,
146        child_idx: usize,
147    ) -> VortexResult<Option<ArrayRef>> {
148        crate::rules::RULES.evaluate(array, parent, child_idx)
149    }
150
151    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
152        // convert truthy values to set/unset bits
153        let boolean_buffer = BitBufferMut::from(array.truthy_bytes()).freeze();
154        let validity = array.validity()?;
155        Ok(ExecutionResult::done(
156            BoolArray::new(boolean_buffer, validity).into_array(),
157        ))
158    }
159
160    fn execute_parent(
161        array: ArrayView<'_, Self>,
162        parent: &ArrayRef,
163        child_idx: usize,
164        ctx: &mut ExecutionCtx,
165    ) -> VortexResult<Option<ArrayRef>> {
166        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
167    }
168}
169
170/// The validity bitmap indicating which elements are non-null.
171pub(super) const VALIDITY_SLOT: usize = 0;
172pub(super) const NUM_SLOTS: usize = 1;
173pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
174
175#[derive(Clone, Debug)]
176pub struct ByteBoolData {
177    buffer: BufferHandle,
178}
179
180impl Display for ByteBoolData {
181    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
182        Ok(())
183    }
184}
185
186pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
187    fn validity(&self) -> Validity {
188        child_to_validity(
189            self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
190            self.as_ref().dtype().nullability(),
191        )
192    }
193}
194
195impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
196
/// Marker type for the byte-per-boolean encoding (`"vortex.bytebool"`); it also
/// hosts the constructors for [`ByteBoolArray`].
#[derive(Debug, Clone)]
pub struct ByteBool;
199
200impl ByteBool {
201    pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
202        if let Some(len) = validity.maybe_len() {
203            assert_eq!(
204                buffer.len(),
205                len,
206                "ByteBool validity and bytes must have same length"
207            );
208        }
209        let dtype = DType::Bool(validity.nullability());
210
211        let slots = ByteBoolData::make_slots(&validity, buffer.len());
212        let data = ByteBoolData::new(buffer);
213        let len = data.len();
214        unsafe {
215            Array::from_parts_unchecked(
216                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
217            )
218        }
219    }
220
221    /// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
222    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
223        let validity = validity.into();
224        // NOTE: this will not cause allocation on release builds
225        let bytes: Vec<u8> = data.into_iter().map(|b| b as u8).collect();
226        let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
227        ByteBool::new(handle, validity)
228    }
229
230    /// Construct a [`ByteBoolArray`] from optional bools.
231    pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
232        let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
233        // NOTE: this will not cause allocation on release builds
234        let bytes: Vec<u8> = data
235            .into_iter()
236            .map(|b| b.unwrap_or_default() as u8)
237            .collect();
238        let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
239        ByteBool::new(handle, validity)
240    }
241}
242
243impl ByteBoolData {
244    pub fn validate(
245        buffer: &BufferHandle,
246        validity: &Validity,
247        dtype: &DType,
248        len: usize,
249    ) -> VortexResult<()> {
250        let expected_dtype = DType::Bool(validity.nullability());
251        vortex_ensure!(
252            dtype == &expected_dtype,
253            "expected dtype {expected_dtype}, got {dtype}"
254        );
255        vortex_ensure!(
256            buffer.len() == len,
257            "expected len {len}, got {}",
258            buffer.len()
259        );
260        if let Some(vlen) = validity.maybe_len() {
261            vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
262        }
263        Ok(())
264    }
265
266    fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
267        vec![validity_to_child(validity, len)]
268    }
269
270    pub fn new(buffer: BufferHandle) -> Self {
271        Self { buffer }
272    }
273
274    /// Returns the number of elements in the array.
275    pub fn len(&self) -> usize {
276        self.buffer.len()
277    }
278
279    /// Returns `true` if the array contains no elements.
280    pub fn is_empty(&self) -> bool {
281        self.buffer.len() == 0
282    }
283
284    pub fn buffer(&self) -> &BufferHandle {
285        &self.buffer
286    }
287
288    /// Get access to the underlying 8-bit truthy values.
289    ///
290    /// The zero byte indicates `false`, and any non-zero byte is a `true`.
291    pub fn truthy_bytes(&self) -> &[u8] {
292        self.buffer().as_host().as_slice()
293    }
294}
295
296impl ValidityVTable<ByteBool> for ByteBool {
297    fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
298        Ok(ByteBoolArrayExt::validity(&array))
299    }
300}
301
302impl OperationsVTable<ByteBool> for ByteBool {
303    fn scalar_at(
304        array: ArrayView<'_, ByteBool>,
305        index: usize,
306        _ctx: &mut ExecutionCtx,
307    ) -> VortexResult<Scalar> {
308        Ok(Scalar::bool(
309            array.buffer.as_host()[index] == 1,
310            array.dtype().nullability(),
311        ))
312    }
313}
314
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    /// Exercises both constructors and checks the length and per-element
    /// validity they produce.
    #[test]
    fn test_validity_construction() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Explicit all-valid validity: every position must report valid.
        let bools = vec![true, false];
        let expected_len = bools.len();
        let arr = ByteBool::from_vec(bools, Validity::AllValid);
        assert_eq!(expected_len, arr.len());
        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx, &mut ctx).unwrap());
        }

        // Mixed options: `None` positions are the invalid ones.
        let mixed = vec![Some(true), None, Some(false)];
        let arr = ByteBool::from_option_vec(mixed);
        assert!(arr.is_valid(0, &mut ctx).unwrap());
        assert!(!arr.is_valid(1, &mut ctx).unwrap());
        assert!(arr.is_valid(2, &mut ctx).unwrap());
        assert_eq!(arr.len(), 3);

        // All `None`: nothing is valid, but the length is preserved.
        let all_null: Vec<Option<bool>> = vec![None, None];
        let expected_len = all_null.len();
        let arr = ByteBool::from_option_vec(all_null);
        assert_eq!(expected_len, arr.len());
        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx, &mut ctx).unwrap());
        }
        assert_eq!(arr.len(), 2);
    }

    /// Serializes a nullable array, decodes it again, and checks that the
    /// decoded array equals the original.
    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let array = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let dtype = array.dtype().clone();
        let len = array.len();

        // The encoding must be registered in the session used for the roundtrip.
        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let ctx = ArrayContext::empty();
        let chunks = array
            .clone()
            .into_array()
            .serialize(&ctx, &session, &SerializeOptions::default())
            .unwrap();

        // Concatenate the serialized chunks into one contiguous buffer.
        let mut joined = ByteBufferMut::empty();
        for chunk in chunks {
            joined.extend_from_slice(chunk.as_ref());
        }

        let serialized = SerializedArray::try_from(joined.freeze()).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = serialized.decode(&dtype, len, &read_ctx, &session).unwrap();

        assert_arrays_eq!(decoded, array);
    }
}