Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_array::ArrayBufferVisitor;
8use vortex_array::ArrayChildVisitor;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::EmptyMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::arrays::BoolArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::dtype::DType;
19use vortex_array::scalar::Scalar;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::BaseArrayVTable;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityHelper;
30use vortex_array::vtable::ValidityVTableFromValidityHelper;
31use vortex_array::vtable::VisitorVTable;
32use vortex_array::vtable::validity_nchildren;
33use vortex_buffer::BitBuffer;
34use vortex_buffer::ByteBuffer;
35use vortex_error::VortexExpect as _;
36use vortex_error::VortexResult;
37use vortex_error::vortex_bail;
38use vortex_error::vortex_ensure;
39use vortex_error::vortex_panic;
40use vortex_session::VortexSession;
41
42use crate::kernel::PARENT_KERNELS;
43
// Invokes the `vtable!` macro from `vortex_array` for the `ByteBool` encoding.
// NOTE(review): presumably this generates the glue that lets `ByteBoolArray` be used as a
// generic Vortex array (e.g. the `len()` / `is_valid()` / `as_ref()` calls used further down
// in this file) — confirm against the macro definition.
vtable!(ByteBool);
45
46impl VTable for ByteBoolVTable {
47    type Array = ByteBoolArray;
48
49    type Metadata = EmptyMetadata;
50
51    type ArrayVTable = Self;
52    type OperationsVTable = Self;
53    type ValidityVTable = ValidityVTableFromValidityHelper;
54    type VisitorVTable = Self;
55
56    fn id(_array: &Self::Array) -> ArrayId {
57        Self::ID
58    }
59
60    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
61        Ok(EmptyMetadata)
62    }
63
64    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
65        Ok(Some(vec![]))
66    }
67
68    fn deserialize(
69        _bytes: &[u8],
70        _dtype: &DType,
71        _len: usize,
72        _buffers: &[BufferHandle],
73        _session: &VortexSession,
74    ) -> VortexResult<Self::Metadata> {
75        Ok(EmptyMetadata)
76    }
77
78    fn build(
79        dtype: &DType,
80        len: usize,
81        _metadata: &Self::Metadata,
82        buffers: &[BufferHandle],
83        children: &dyn ArrayChildren,
84    ) -> VortexResult<ByteBoolArray> {
85        let validity = if children.is_empty() {
86            Validity::from(dtype.nullability())
87        } else if children.len() == 1 {
88            let validity = children.get(0, &Validity::DTYPE, len)?;
89            Validity::Array(validity)
90        } else {
91            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
92        };
93
94        if buffers.len() != 1 {
95            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
96        }
97        let buffer = buffers[0].clone();
98
99        Ok(ByteBoolArray::new(buffer, validity))
100    }
101
102    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
103        vortex_ensure!(
104            children.len() <= 1,
105            "ByteBoolArray expects at most 1 child (validity), got {}",
106            children.len()
107        );
108
109        array.validity = if children.is_empty() {
110            Validity::from(array.dtype.nullability())
111        } else {
112            Validity::Array(children.into_iter().next().vortex_expect("checked"))
113        };
114
115        Ok(())
116    }
117
118    fn reduce_parent(
119        array: &Self::Array,
120        parent: &ArrayRef,
121        child_idx: usize,
122    ) -> VortexResult<Option<ArrayRef>> {
123        crate::rules::RULES.evaluate(array, parent, child_idx)
124    }
125
126    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
127        let boolean_buffer = BitBuffer::from(array.as_slice());
128        let validity = array.validity().clone();
129        Ok(BoolArray::new(boolean_buffer, validity).into_array())
130    }
131
132    fn execute_parent(
133        array: &Self::Array,
134        parent: &ArrayRef,
135        child_idx: usize,
136        ctx: &mut ExecutionCtx,
137    ) -> VortexResult<Option<ArrayRef>> {
138        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
139    }
140}
141
/// A boolean array that stores each value in a full byte rather than a single bit.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    // Always `DType::Bool`, with the nullability taken from `validity` at construction.
    dtype: DType,
    // One byte per element; the array length is the length of this buffer.
    buffer: BufferHandle,
    // Validity (null mask) of the elements.
    validity: Validity,
    // Statistics holder for this array; starts out as `Default::default()`.
    stats_set: ArrayStats,
}
149
/// Zero-sized marker type on which the vtable traits for [`ByteBoolArray`] are implemented.
#[derive(Debug)]
pub struct ByteBoolVTable;
152
impl ByteBoolVTable {
    /// Identifier of the byte-bool encoding (`"vortex.bytebool"`).
    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
}
156
157impl ByteBoolArray {
158    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
159        let length = buffer.len();
160        if let Some(vlen) = validity.maybe_len()
161            && length != vlen
162        {
163            vortex_panic!(
164                "Buffer length ({}) does not match validity length ({})",
165                length,
166                vlen
167            );
168        }
169        Self {
170            dtype: DType::Bool(validity.nullability()),
171            buffer,
172            validity,
173            stats_set: Default::default(),
174        }
175    }
176
177    // TODO(ngates): deprecate construction from vec
178    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
179        let validity = validity.into();
180        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
181        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
182        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
183    }
184
185    pub fn buffer(&self) -> &BufferHandle {
186        &self.buffer
187    }
188
189    pub fn as_slice(&self) -> &[bool] {
190        // Safety: The internal buffer contains byte-sized bools
191        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
192    }
193}
194
/// Exposes the stored validity so the validity vtable can be derived from it
/// (see `ValidityVTableFromValidityHelper` in the `VTable` impl).
impl ValidityHelper for ByteBoolArray {
    /// Returns a reference to the array's validity.
    fn validity(&self) -> &Validity {
        &self.validity
    }
}
200
201impl BaseArrayVTable<ByteBoolVTable> for ByteBoolVTable {
202    fn len(array: &ByteBoolArray) -> usize {
203        array.buffer.len()
204    }
205
206    fn dtype(array: &ByteBoolArray) -> &DType {
207        &array.dtype
208    }
209
210    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
211        array.stats_set.to_ref(array.as_ref())
212    }
213
214    fn array_hash<H: std::hash::Hasher>(
215        array: &ByteBoolArray,
216        state: &mut H,
217        precision: Precision,
218    ) {
219        array.dtype.hash(state);
220        array.buffer.array_hash(state, precision);
221        array.validity.array_hash(state, precision);
222    }
223
224    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
225        array.dtype == other.dtype
226            && array.buffer.array_eq(&other.buffer, precision)
227            && array.validity.array_eq(&other.validity, precision)
228    }
229}
230
231impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
232    fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
233        Ok(Scalar::bool(
234            array.buffer.as_host()[index] == 1,
235            array.dtype().nullability(),
236        ))
237    }
238}
239
240impl VisitorVTable<ByteBoolVTable> for ByteBoolVTable {
241    fn visit_buffers(array: &ByteBoolArray, visitor: &mut dyn ArrayBufferVisitor) {
242        visitor.visit_buffer_handle("values", array.buffer());
243    }
244
245    fn nbuffers(_array: &ByteBoolArray) -> usize {
246        1
247    }
248
249    fn visit_children(array: &ByteBoolArray, visitor: &mut dyn ArrayChildVisitor) {
250        visitor.visit_validity(array.validity(), array.len());
251    }
252
253    fn nchildren(array: &ByteBoolArray) -> usize {
254        validity_nchildren(array.validity())
255    }
256}
257
258impl From<Vec<bool>> for ByteBoolArray {
259    fn from(value: Vec<bool>) -> Self {
260        Self::from_vec(value, Validity::AllValid)
261    }
262}
263
264impl From<Vec<Option<bool>>> for ByteBoolArray {
265    fn from(value: Vec<Option<bool>>) -> Self {
266        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
267
268        // This doesn't reallocate, and the compiler even vectorizes it
269        let data = value.into_iter().map(Option::unwrap_or_default).collect();
270
271        Self::from_vec(data, validity)
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks length and per-element validity for arrays built from `Vec<bool>` and
    /// `Vec<Option<bool>>`.
    #[test]
    fn test_validity_construction() {
        // Plain booleans: every element is valid.
        let values = vec![true, false];
        let expected_len = values.len();

        let arr = ByteBoolArray::from(values);
        assert_eq!(expected_len, arr.len());
        assert!((0..arr.len()).all(|idx| arr.is_valid(idx).unwrap()));

        // Mixed present/absent values: only the `None` slot is invalid.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());
        assert_eq!(arr.len(), 3);

        // All absent: no element is valid.
        let values: Vec<Option<bool>> = vec![None, None];
        let expected_len = values.len();

        let arr = ByteBoolArray::from(values);
        assert_eq!(expected_len, arr.len());
        assert!((0..arr.len()).all(|idx| !arr.is_valid(idx).unwrap()));
        assert_eq!(arr.len(), 2);
    }
}
309}