Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_array::ArrayEq;
9use vortex_array::ArrayHash;
10use vortex_array::ArrayRef;
11use vortex_array::EmptyMetadata;
12use vortex_array::ExecutionCtx;
13use vortex_array::ExecutionResult;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::arrays::BoolArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::dtype::DType;
19use vortex_array::scalar::Scalar;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::Array;
26use vortex_array::vtable::ArrayId;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityHelper;
30use vortex_array::vtable::ValidityVTableFromValidityHelper;
31use vortex_array::vtable::validity_nchildren;
32use vortex_array::vtable::validity_to_child;
33use vortex_buffer::BitBuffer;
34use vortex_buffer::ByteBuffer;
35use vortex_error::VortexExpect as _;
36use vortex_error::VortexResult;
37use vortex_error::vortex_bail;
38use vortex_error::vortex_ensure;
39use vortex_error::vortex_panic;
40use vortex_session::VortexSession;
41
42use crate::kernel::PARENT_KERNELS;
43
44vtable!(ByteBool);
45
46impl VTable for ByteBool {
47    type Array = ByteBoolArray;
48
49    type Metadata = EmptyMetadata;
50    type OperationsVTable = Self;
51    type ValidityVTable = ValidityVTableFromValidityHelper;
52
53    fn vtable(_array: &Self::Array) -> &Self {
54        &ByteBool
55    }
56
57    fn id(&self) -> ArrayId {
58        Self::ID
59    }
60
61    fn len(array: &ByteBoolArray) -> usize {
62        array.buffer.len()
63    }
64
65    fn dtype(array: &ByteBoolArray) -> &DType {
66        &array.dtype
67    }
68
69    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
70        array.stats_set.to_ref(array.as_ref())
71    }
72
73    fn array_hash<H: std::hash::Hasher>(
74        array: &ByteBoolArray,
75        state: &mut H,
76        precision: Precision,
77    ) {
78        array.dtype.hash(state);
79        array.buffer.array_hash(state, precision);
80        array.validity.array_hash(state, precision);
81    }
82
83    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
84        array.dtype == other.dtype
85            && array.buffer.array_eq(&other.buffer, precision)
86            && array.validity.array_eq(&other.validity, precision)
87    }
88
89    fn nbuffers(_array: &ByteBoolArray) -> usize {
90        1
91    }
92
93    fn buffer(array: &ByteBoolArray, idx: usize) -> BufferHandle {
94        match idx {
95            0 => array.buffer().clone(),
96            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
97        }
98    }
99
100    fn buffer_name(_array: &ByteBoolArray, idx: usize) -> Option<String> {
101        match idx {
102            0 => Some("values".to_string()),
103            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
104        }
105    }
106
107    fn nchildren(array: &ByteBoolArray) -> usize {
108        validity_nchildren(array.validity())
109    }
110
111    fn child(array: &ByteBoolArray, idx: usize) -> ArrayRef {
112        match idx {
113            0 => validity_to_child(array.validity(), array.len())
114                .vortex_expect("ByteBoolArray validity child out of bounds"),
115            _ => vortex_panic!("ByteBoolArray child index {idx} out of bounds"),
116        }
117    }
118
119    fn child_name(_array: &ByteBoolArray, idx: usize) -> String {
120        match idx {
121            0 => "validity".to_string(),
122            _ => vortex_panic!("ByteBoolArray child_name index {idx} out of bounds"),
123        }
124    }
125
126    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
127        Ok(EmptyMetadata)
128    }
129
130    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
131        Ok(Some(vec![]))
132    }
133
134    fn deserialize(
135        _bytes: &[u8],
136        _dtype: &DType,
137        _len: usize,
138        _buffers: &[BufferHandle],
139        _session: &VortexSession,
140    ) -> VortexResult<Self::Metadata> {
141        Ok(EmptyMetadata)
142    }
143
144    fn build(
145        dtype: &DType,
146        len: usize,
147        _metadata: &Self::Metadata,
148        buffers: &[BufferHandle],
149        children: &dyn ArrayChildren,
150    ) -> VortexResult<ByteBoolArray> {
151        let validity = if children.is_empty() {
152            Validity::from(dtype.nullability())
153        } else if children.len() == 1 {
154            let validity = children.get(0, &Validity::DTYPE, len)?;
155            Validity::Array(validity)
156        } else {
157            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
158        };
159
160        if buffers.len() != 1 {
161            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
162        }
163        let buffer = buffers[0].clone();
164
165        Ok(ByteBoolArray::new(buffer, validity))
166    }
167
168    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
169        vortex_ensure!(
170            children.len() <= 1,
171            "ByteBoolArray expects at most 1 child (validity), got {}",
172            children.len()
173        );
174
175        array.validity = if children.is_empty() {
176            Validity::from(array.dtype.nullability())
177        } else {
178            Validity::Array(children.into_iter().next().vortex_expect("checked"))
179        };
180
181        Ok(())
182    }
183
184    fn reduce_parent(
185        array: &Array<Self>,
186        parent: &ArrayRef,
187        child_idx: usize,
188    ) -> VortexResult<Option<ArrayRef>> {
189        crate::rules::RULES.evaluate(array, parent, child_idx)
190    }
191
192    fn execute(array: Arc<Array<Self>>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
193        let boolean_buffer = BitBuffer::from(array.as_slice());
194        let validity = array.validity().clone();
195        Ok(ExecutionResult::done(
196            BoolArray::new(boolean_buffer, validity).into_array(),
197        ))
198    }
199
200    fn execute_parent(
201        array: &Array<Self>,
202        parent: &ArrayRef,
203        child_idx: usize,
204        ctx: &mut ExecutionCtx,
205    ) -> VortexResult<Option<ArrayRef>> {
206        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
207    }
208}
209
210#[derive(Clone, Debug)]
211pub struct ByteBoolArray {
212    dtype: DType,
213    buffer: BufferHandle,
214    validity: Validity,
215    stats_set: ArrayStats,
216}
217
218#[derive(Clone, Debug)]
219pub struct ByteBool;
220
221impl ByteBool {
222    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
223}
224
225impl ByteBoolArray {
226    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
227        let length = buffer.len();
228        if let Some(vlen) = validity.maybe_len()
229            && length != vlen
230        {
231            vortex_panic!(
232                "Buffer length ({}) does not match validity length ({})",
233                length,
234                vlen
235            );
236        }
237        Self {
238            dtype: DType::Bool(validity.nullability()),
239            buffer,
240            validity,
241            stats_set: Default::default(),
242        }
243    }
244
245    // TODO(ngates): deprecate construction from vec
246    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
247        let validity = validity.into();
248        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
249        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
250        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
251    }
252
253    pub fn buffer(&self) -> &BufferHandle {
254        &self.buffer
255    }
256
257    pub fn as_slice(&self) -> &[bool] {
258        // Safety: The internal buffer contains byte-sized bools
259        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
260    }
261}
262
263impl ValidityHelper for ByteBoolArray {
264    fn validity(&self) -> &Validity {
265        &self.validity
266    }
267}
268
269impl OperationsVTable<ByteBool> for ByteBool {
270    fn scalar_at(
271        array: &ByteBoolArray,
272        index: usize,
273        _ctx: &mut ExecutionCtx,
274    ) -> VortexResult<Scalar> {
275        Ok(Scalar::bool(
276            array.buffer.as_host()[index] == 1,
277            array.dtype().nullability(),
278        ))
279    }
280}
281
282impl From<Vec<bool>> for ByteBoolArray {
283    fn from(value: Vec<bool>) -> Self {
284        Self::from_vec(value, Validity::AllValid)
285    }
286}
287
288impl From<Vec<Option<bool>>> for ByteBoolArray {
289    fn from(value: Vec<Option<bool>>) -> Self {
290        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
291
292        // This doesn't reallocate, and the compiler even vectorizes it
293        let data = value.into_iter().map(Option::unwrap_or_default).collect();
294
295        Self::from_vec(data, validity)
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validity_construction() {
        let v = vec![true, false];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx).unwrap());
        }

        let v = vec![Some(true), None, Some(false)];
        let arr = ByteBoolArray::from(v);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());
        assert_eq!(arr.len(), 3);

        let v: Vec<Option<bool>> = vec![None, None];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx).unwrap());
        }
        assert_eq!(arr.len(), 2);
    }

    /// Values written through `from_vec` must read back unchanged through `as_slice`.
    #[test]
    fn test_as_slice_round_trip() {
        let values = vec![true, false, true, true];
        let arr = ByteBoolArray::from(values.clone());
        assert_eq!(arr.as_slice(), values.as_slice());

        // Null slots are stored as `false` in the values buffer.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.as_slice(), [true, false, false].as_slice());

        let arr = ByteBoolArray::from(Vec::<bool>::new());
        assert_eq!(arr.len(), 0);
        assert!(arr.as_slice().is_empty());
    }
}