Skip to main content

vortex_bytebool/
array.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArraySlots;
16use vortex_array::ArrayView;
17use vortex_array::ExecutionCtx;
18use vortex_array::ExecutionResult;
19use vortex_array::IntoArray;
20use vortex_array::Precision;
21use vortex_array::TypedArrayRef;
22use vortex_array::arrays::BoolArray;
23use vortex_array::buffer::BufferHandle;
24use vortex_array::dtype::DType;
25use vortex_array::scalar::Scalar;
26use vortex_array::serde::ArrayChildren;
27use vortex_array::validity::Validity;
28use vortex_array::vtable::OperationsVTable;
29use vortex_array::vtable::VTable;
30use vortex_array::vtable::ValidityVTable;
31use vortex_array::vtable::child_to_validity;
32use vortex_array::vtable::validity_to_child;
33use vortex_buffer::BitBufferMut;
34use vortex_buffer::ByteBuffer;
35use vortex_error::VortexResult;
36use vortex_error::vortex_bail;
37use vortex_error::vortex_ensure;
38use vortex_error::vortex_panic;
39use vortex_session::VortexSession;
40use vortex_session::registry::CachedId;
41
42use crate::kernel::PARENT_KERNELS;
43
44/// A [`ByteBool`]-encoded Vortex array.
45pub type ByteBoolArray = Array<ByteBool>;
46
47impl ArrayHash for ByteBoolData {
48    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
49        self.buffer.array_hash(state, precision);
50    }
51}
52
53impl ArrayEq for ByteBoolData {
54    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
55        self.buffer.array_eq(&other.buffer, precision)
56    }
57}
58
59impl VTable for ByteBool {
60    type TypedArrayData = ByteBoolData;
61
62    type OperationsVTable = Self;
63    type ValidityVTable = Self;
64
65    fn id(&self) -> ArrayId {
66        static ID: CachedId = CachedId::new("vortex.bytebool");
67        *ID
68    }
69
70    fn validate(
71        &self,
72        data: &Self::TypedArrayData,
73        dtype: &DType,
74        len: usize,
75        slots: &[Option<ArrayRef>],
76    ) -> VortexResult<()> {
77        let validity = child_to_validity(slots[VALIDITY_SLOT].as_ref(), dtype.nullability());
78        ByteBoolData::validate(data.buffer(), &validity, dtype, len)
79    }
80
81    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
82        1
83    }
84
85    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
86        match idx {
87            0 => array.buffer().clone(),
88            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
89        }
90    }
91
92    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
93        match idx {
94            0 => Some("values".to_string()),
95            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
96        }
97    }
98
99    fn serialize(
100        _array: ArrayView<'_, Self>,
101        _session: &VortexSession,
102    ) -> VortexResult<Option<Vec<u8>>> {
103        Ok(Some(vec![]))
104    }
105
106    fn deserialize(
107        &self,
108        dtype: &DType,
109        len: usize,
110        metadata: &[u8],
111        buffers: &[BufferHandle],
112        children: &dyn ArrayChildren,
113        _session: &VortexSession,
114    ) -> VortexResult<ArrayParts<Self>> {
115        if !metadata.is_empty() {
116            vortex_bail!(
117                "ByteBoolArray expects empty metadata, got {} bytes",
118                metadata.len()
119            );
120        }
121        let validity = if children.is_empty() {
122            Validity::from(dtype.nullability())
123        } else if children.len() == 1 {
124            let validity = children.get(0, &Validity::DTYPE, len)?;
125            Validity::Array(validity)
126        } else {
127            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
128        };
129
130        if buffers.len() != 1 {
131            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
132        }
133        let buffer = buffers[0].clone();
134
135        let data = ByteBoolData::new(buffer);
136        let slots = ByteBoolData::make_slots(&validity, len);
137        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
138    }
139
140    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
141        SLOT_NAMES[idx].to_string()
142    }
143
144    fn reduce_parent(
145        array: ArrayView<'_, Self>,
146        parent: &ArrayRef,
147        child_idx: usize,
148    ) -> VortexResult<Option<ArrayRef>> {
149        crate::rules::RULES.evaluate(array, parent, child_idx)
150    }
151
152    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
153        // convert truthy values to set/unset bits
154        let boolean_buffer = BitBufferMut::from(array.truthy_bytes()).freeze();
155        let validity = array.validity()?;
156        Ok(ExecutionResult::done(
157            BoolArray::new(boolean_buffer, validity).into_array(),
158        ))
159    }
160
161    fn execute_parent(
162        array: ArrayView<'_, Self>,
163        parent: &ArrayRef,
164        child_idx: usize,
165        ctx: &mut ExecutionCtx,
166    ) -> VortexResult<Option<ArrayRef>> {
167        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
168    }
169}
170
171/// The validity bitmap indicating which elements are non-null.
172pub(super) const VALIDITY_SLOT: usize = 0;
173pub(super) const NUM_SLOTS: usize = 1;
174pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
175
176#[derive(Clone, Debug)]
177pub struct ByteBoolData {
178    buffer: BufferHandle,
179}
180
181impl Display for ByteBoolData {
182    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
183        Ok(())
184    }
185}
186
187pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
188    fn validity(&self) -> Validity {
189        child_to_validity(
190            self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
191            self.as_ref().dtype().nullability(),
192        )
193    }
194}
195
196impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
197
/// Marker type for the byte-per-boolean encoding, registered as `vortex.bytebool`.
#[derive(Debug, Clone)]
pub struct ByteBool;
200
201impl ByteBool {
202    pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
203        if let Some(len) = validity.maybe_len() {
204            assert_eq!(
205                buffer.len(),
206                len,
207                "ByteBool validity and bytes must have same length"
208            );
209        }
210        let dtype = DType::Bool(validity.nullability());
211
212        let slots = ByteBoolData::make_slots(&validity, buffer.len());
213        let data = ByteBoolData::new(buffer);
214        let len = data.len();
215        unsafe {
216            Array::from_parts_unchecked(
217                ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
218            )
219        }
220    }
221
222    /// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
223    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
224        let validity = validity.into();
225        // NOTE: this will not cause allocation on release builds
226        let bytes: Vec<u8> = data.into_iter().map(|b| b as u8).collect();
227        let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
228        ByteBool::new(handle, validity)
229    }
230
231    /// Construct a [`ByteBoolArray`] from optional bools.
232    pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
233        let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
234        // NOTE: this will not cause allocation on release builds
235        let bytes: Vec<u8> = data
236            .into_iter()
237            .map(|b| b.unwrap_or_default() as u8)
238            .collect();
239        let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
240        ByteBool::new(handle, validity)
241    }
242}
243
244impl ByteBoolData {
245    pub fn validate(
246        buffer: &BufferHandle,
247        validity: &Validity,
248        dtype: &DType,
249        len: usize,
250    ) -> VortexResult<()> {
251        let expected_dtype = DType::Bool(validity.nullability());
252        vortex_ensure!(
253            dtype == &expected_dtype,
254            "expected dtype {expected_dtype}, got {dtype}"
255        );
256        vortex_ensure!(
257            buffer.len() == len,
258            "expected len {len}, got {}",
259            buffer.len()
260        );
261        if let Some(vlen) = validity.maybe_len() {
262            vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
263        }
264        Ok(())
265    }
266
267    fn make_slots(validity: &Validity, len: usize) -> ArraySlots {
268        vec![validity_to_child(validity, len)].into()
269    }
270
271    pub fn new(buffer: BufferHandle) -> Self {
272        Self { buffer }
273    }
274
275    /// Returns the number of elements in the array.
276    pub fn len(&self) -> usize {
277        self.buffer.len()
278    }
279
280    /// Returns `true` if the array contains no elements.
281    pub fn is_empty(&self) -> bool {
282        self.buffer.len() == 0
283    }
284
285    pub fn buffer(&self) -> &BufferHandle {
286        &self.buffer
287    }
288
289    /// Get access to the underlying 8-bit truthy values.
290    ///
291    /// The zero byte indicates `false`, and any non-zero byte is a `true`.
292    pub fn truthy_bytes(&self) -> &[u8] {
293        self.buffer().as_host().as_slice()
294    }
295}
296
297impl ValidityVTable<ByteBool> for ByteBool {
298    fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
299        Ok(ByteBoolArrayExt::validity(&array))
300    }
301}
302
303impl OperationsVTable<ByteBool> for ByteBool {
304    fn scalar_at(
305        array: ArrayView<'_, ByteBool>,
306        index: usize,
307        _ctx: &mut ExecutionCtx,
308    ) -> VortexResult<Scalar> {
309        Ok(Scalar::bool(
310            array.buffer.as_host()[index] == 1,
311            array.dtype().nullability(),
312        ))
313    }
314}
315
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    /// Validity is reported correctly for all-valid, mixed, and all-null inputs.
    #[test]
    fn test_validity_construction() {
        let values = vec![true, false];
        let expected_len = values.len();

        let all_valid = ByteBool::from_vec(values, Validity::AllValid);
        assert_eq!(expected_len, all_valid.len());

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for idx in 0..all_valid.len() {
            assert!(all_valid.is_valid(idx, &mut ctx).unwrap());
        }

        // `None` entries must show up as nulls.
        let mixed = ByteBool::from_option_vec(vec![Some(true), None, Some(false)]);
        let expected_validity = [true, false, true];
        for (idx, want) in expected_validity.into_iter().enumerate() {
            assert_eq!(mixed.is_valid(idx, &mut ctx).unwrap(), want);
        }
        assert_eq!(mixed.len(), 3);

        let all_null_input: Vec<Option<bool>> = vec![None, None];
        let all_null_len = all_null_input.len();

        let all_null = ByteBool::from_option_vec(all_null_input);
        assert_eq!(all_null_len, all_null.len());

        for idx in 0..all_null.len() {
            assert!(!all_null.is_valid(idx, &mut ctx).unwrap());
        }
        assert_eq!(all_null.len(), 2);
    }

    /// A nullable ByteBool array survives a serialize/decode round trip.
    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let array = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let (dtype, len) = (array.dtype().clone(), array.len());

        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let ctx = ArrayContext::empty();
        let serialized = array
            .clone()
            .into_array()
            .serialize(&ctx, &session, &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized segments into one contiguous buffer.
        let concat = serialized
            .into_iter()
            .fold(ByteBufferMut::empty(), |mut acc, buf| {
                acc.extend_from_slice(buf.as_ref());
                acc
            });

        let parts = SerializedArray::try_from(concat.freeze()).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = parts.decode(&dtype, len, &read_ctx, &session).unwrap();

        assert_arrays_eq!(decoded, array);
    }
}