Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_array::ArrayEq;
8use vortex_array::ArrayHash;
9use vortex_array::ArrayRef;
10use vortex_array::EmptyMetadata;
11use vortex_array::ExecutionCtx;
12use vortex_array::IntoArray;
13use vortex_array::Precision;
14use vortex_array::arrays::BoolArray;
15use vortex_array::buffer::BufferHandle;
16use vortex_array::dtype::DType;
17use vortex_array::scalar::Scalar;
18use vortex_array::serde::ArrayChildren;
19use vortex_array::stats::ArrayStats;
20use vortex_array::stats::StatsSetRef;
21use vortex_array::validity::Validity;
22use vortex_array::vtable;
23use vortex_array::vtable::ArrayId;
24use vortex_array::vtable::OperationsVTable;
25use vortex_array::vtable::VTable;
26use vortex_array::vtable::ValidityHelper;
27use vortex_array::vtable::ValidityVTableFromValidityHelper;
28use vortex_array::vtable::validity_nchildren;
29use vortex_array::vtable::validity_to_child;
30use vortex_buffer::BitBuffer;
31use vortex_buffer::ByteBuffer;
32use vortex_error::VortexExpect as _;
33use vortex_error::VortexResult;
34use vortex_error::vortex_bail;
35use vortex_error::vortex_ensure;
36use vortex_error::vortex_panic;
37use vortex_session::VortexSession;
38
39use crate::kernel::PARENT_KERNELS;
40
41vtable!(ByteBool);
42
43impl VTable for ByteBoolVTable {
44    type Array = ByteBoolArray;
45
46    type Metadata = EmptyMetadata;
47    type OperationsVTable = Self;
48    type ValidityVTable = ValidityVTableFromValidityHelper;
49
50    fn id(_array: &Self::Array) -> ArrayId {
51        Self::ID
52    }
53
54    fn len(array: &ByteBoolArray) -> usize {
55        array.buffer.len()
56    }
57
58    fn dtype(array: &ByteBoolArray) -> &DType {
59        &array.dtype
60    }
61
62    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
63        array.stats_set.to_ref(array.as_ref())
64    }
65
66    fn array_hash<H: std::hash::Hasher>(
67        array: &ByteBoolArray,
68        state: &mut H,
69        precision: Precision,
70    ) {
71        array.dtype.hash(state);
72        array.buffer.array_hash(state, precision);
73        array.validity.array_hash(state, precision);
74    }
75
76    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
77        array.dtype == other.dtype
78            && array.buffer.array_eq(&other.buffer, precision)
79            && array.validity.array_eq(&other.validity, precision)
80    }
81
82    fn nbuffers(_array: &ByteBoolArray) -> usize {
83        1
84    }
85
86    fn buffer(array: &ByteBoolArray, idx: usize) -> BufferHandle {
87        match idx {
88            0 => array.buffer().clone(),
89            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
90        }
91    }
92
93    fn buffer_name(_array: &ByteBoolArray, idx: usize) -> Option<String> {
94        match idx {
95            0 => Some("values".to_string()),
96            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
97        }
98    }
99
100    fn nchildren(array: &ByteBoolArray) -> usize {
101        validity_nchildren(array.validity())
102    }
103
104    fn child(array: &ByteBoolArray, idx: usize) -> ArrayRef {
105        match idx {
106            0 => validity_to_child(array.validity(), array.len())
107                .vortex_expect("ByteBoolArray validity child out of bounds"),
108            _ => vortex_panic!("ByteBoolArray child index {idx} out of bounds"),
109        }
110    }
111
112    fn child_name(_array: &ByteBoolArray, idx: usize) -> String {
113        match idx {
114            0 => "validity".to_string(),
115            _ => vortex_panic!("ByteBoolArray child_name index {idx} out of bounds"),
116        }
117    }
118
119    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
120        Ok(EmptyMetadata)
121    }
122
123    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
124        Ok(Some(vec![]))
125    }
126
127    fn deserialize(
128        _bytes: &[u8],
129        _dtype: &DType,
130        _len: usize,
131        _buffers: &[BufferHandle],
132        _session: &VortexSession,
133    ) -> VortexResult<Self::Metadata> {
134        Ok(EmptyMetadata)
135    }
136
137    fn build(
138        dtype: &DType,
139        len: usize,
140        _metadata: &Self::Metadata,
141        buffers: &[BufferHandle],
142        children: &dyn ArrayChildren,
143    ) -> VortexResult<ByteBoolArray> {
144        let validity = if children.is_empty() {
145            Validity::from(dtype.nullability())
146        } else if children.len() == 1 {
147            let validity = children.get(0, &Validity::DTYPE, len)?;
148            Validity::Array(validity)
149        } else {
150            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
151        };
152
153        if buffers.len() != 1 {
154            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
155        }
156        let buffer = buffers[0].clone();
157
158        Ok(ByteBoolArray::new(buffer, validity))
159    }
160
161    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
162        vortex_ensure!(
163            children.len() <= 1,
164            "ByteBoolArray expects at most 1 child (validity), got {}",
165            children.len()
166        );
167
168        array.validity = if children.is_empty() {
169            Validity::from(array.dtype.nullability())
170        } else {
171            Validity::Array(children.into_iter().next().vortex_expect("checked"))
172        };
173
174        Ok(())
175    }
176
177    fn reduce_parent(
178        array: &Self::Array,
179        parent: &ArrayRef,
180        child_idx: usize,
181    ) -> VortexResult<Option<ArrayRef>> {
182        crate::rules::RULES.evaluate(array, parent, child_idx)
183    }
184
185    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
186        let boolean_buffer = BitBuffer::from(array.as_slice());
187        let validity = array.validity().clone();
188        Ok(BoolArray::new(boolean_buffer, validity).into_array())
189    }
190
191    fn execute_parent(
192        array: &Self::Array,
193        parent: &ArrayRef,
194        child_idx: usize,
195        ctx: &mut ExecutionCtx,
196    ) -> VortexResult<Option<ArrayRef>> {
197        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
198    }
199}
200
201#[derive(Clone, Debug)]
202pub struct ByteBoolArray {
203    dtype: DType,
204    buffer: BufferHandle,
205    validity: Validity,
206    stats_set: ArrayStats,
207}
208
209#[derive(Debug)]
210pub struct ByteBoolVTable;
211
212impl ByteBoolVTable {
213    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
214}
215
216impl ByteBoolArray {
217    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
218        let length = buffer.len();
219        if let Some(vlen) = validity.maybe_len()
220            && length != vlen
221        {
222            vortex_panic!(
223                "Buffer length ({}) does not match validity length ({})",
224                length,
225                vlen
226            );
227        }
228        Self {
229            dtype: DType::Bool(validity.nullability()),
230            buffer,
231            validity,
232            stats_set: Default::default(),
233        }
234    }
235
236    // TODO(ngates): deprecate construction from vec
237    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
238        let validity = validity.into();
239        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
240        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
241        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
242    }
243
244    pub fn buffer(&self) -> &BufferHandle {
245        &self.buffer
246    }
247
248    pub fn as_slice(&self) -> &[bool] {
249        // Safety: The internal buffer contains byte-sized bools
250        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
251    }
252}
253
254impl ValidityHelper for ByteBoolArray {
255    fn validity(&self) -> &Validity {
256        &self.validity
257    }
258}
259
260impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
261    fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
262        Ok(Scalar::bool(
263            array.buffer.as_host()[index] == 1,
264            array.dtype().nullability(),
265        ))
266    }
267}
268
269impl From<Vec<bool>> for ByteBoolArray {
270    fn from(value: Vec<bool>) -> Self {
271        Self::from_vec(value, Validity::AllValid)
272    }
273}
274
275impl From<Vec<Option<bool>>> for ByteBoolArray {
276    fn from(value: Vec<Option<bool>>) -> Self {
277        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
278
279        // This doesn't reallocate, and the compiler even vectorizes it
280        let data = value.into_iter().map(Option::unwrap_or_default).collect();
281
282        Self::from_vec(data, validity)
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validity_construction() {
        // Plain booleans: every slot is valid.
        let plain = vec![true, false];
        let expected_len = plain.len();
        let array = ByteBoolArray::from(plain);
        assert_eq!(expected_len, array.len());
        assert!((0..array.len()).all(|i| array.is_valid(i).unwrap()));

        // Mixed optionals: validity mirrors Some/None per slot.
        let array = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        let observed: Vec<bool> = (0..3).map(|i| array.is_valid(i).unwrap()).collect();
        assert_eq!(observed, [true, false, true]);
        assert_eq!(array.len(), 3);

        // All None: every slot is invalid.
        let nones: Vec<Option<bool>> = vec![None, None];
        let expected_len = nones.len();
        let array = ByteBoolArray::from(nones);
        assert_eq!(expected_len, array.len());
        assert!((0..array.len()).all(|i| !array.is_valid(i).unwrap()));
        assert_eq!(array.len(), 2);
    }
}