Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_array::ArrayEq;
8use vortex_array::ArrayHash;
9use vortex_array::ArrayRef;
10use vortex_array::EmptyMetadata;
11use vortex_array::ExecutionCtx;
12use vortex_array::ExecutionStep;
13use vortex_array::IntoArray;
14use vortex_array::Precision;
15use vortex_array::arrays::BoolArray;
16use vortex_array::buffer::BufferHandle;
17use vortex_array::dtype::DType;
18use vortex_array::scalar::Scalar;
19use vortex_array::serde::ArrayChildren;
20use vortex_array::stats::ArrayStats;
21use vortex_array::stats::StatsSetRef;
22use vortex_array::validity::Validity;
23use vortex_array::vtable;
24use vortex_array::vtable::ArrayId;
25use vortex_array::vtable::OperationsVTable;
26use vortex_array::vtable::VTable;
27use vortex_array::vtable::ValidityHelper;
28use vortex_array::vtable::ValidityVTableFromValidityHelper;
29use vortex_array::vtable::validity_nchildren;
30use vortex_array::vtable::validity_to_child;
31use vortex_buffer::BitBuffer;
32use vortex_buffer::ByteBuffer;
33use vortex_error::VortexExpect as _;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_session::VortexSession;
39
40use crate::kernel::PARENT_KERNELS;
41
// Invokes the `vtable!` macro imported from `vortex_array` for the `ByteBool` encoding.
// NOTE(review): presumably this generates the glue between `ByteBoolArray` and `ByteBoolVTable`
// (e.g. the `as_ref()` conversion used by `stats` below) — confirm against the macro definition.
vtable!(ByteBool);
43
44impl VTable for ByteBoolVTable {
45    type Array = ByteBoolArray;
46
47    type Metadata = EmptyMetadata;
48    type OperationsVTable = Self;
49    type ValidityVTable = ValidityVTableFromValidityHelper;
50
51    fn id(_array: &Self::Array) -> ArrayId {
52        Self::ID
53    }
54
55    fn len(array: &ByteBoolArray) -> usize {
56        array.buffer.len()
57    }
58
59    fn dtype(array: &ByteBoolArray) -> &DType {
60        &array.dtype
61    }
62
63    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
64        array.stats_set.to_ref(array.as_ref())
65    }
66
67    fn array_hash<H: std::hash::Hasher>(
68        array: &ByteBoolArray,
69        state: &mut H,
70        precision: Precision,
71    ) {
72        array.dtype.hash(state);
73        array.buffer.array_hash(state, precision);
74        array.validity.array_hash(state, precision);
75    }
76
77    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
78        array.dtype == other.dtype
79            && array.buffer.array_eq(&other.buffer, precision)
80            && array.validity.array_eq(&other.validity, precision)
81    }
82
83    fn nbuffers(_array: &ByteBoolArray) -> usize {
84        1
85    }
86
87    fn buffer(array: &ByteBoolArray, idx: usize) -> BufferHandle {
88        match idx {
89            0 => array.buffer().clone(),
90            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
91        }
92    }
93
94    fn buffer_name(_array: &ByteBoolArray, idx: usize) -> Option<String> {
95        match idx {
96            0 => Some("values".to_string()),
97            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
98        }
99    }
100
101    fn nchildren(array: &ByteBoolArray) -> usize {
102        validity_nchildren(array.validity())
103    }
104
105    fn child(array: &ByteBoolArray, idx: usize) -> ArrayRef {
106        match idx {
107            0 => validity_to_child(array.validity(), array.len())
108                .vortex_expect("ByteBoolArray validity child out of bounds"),
109            _ => vortex_panic!("ByteBoolArray child index {idx} out of bounds"),
110        }
111    }
112
113    fn child_name(_array: &ByteBoolArray, idx: usize) -> String {
114        match idx {
115            0 => "validity".to_string(),
116            _ => vortex_panic!("ByteBoolArray child_name index {idx} out of bounds"),
117        }
118    }
119
120    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
121        Ok(EmptyMetadata)
122    }
123
124    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
125        Ok(Some(vec![]))
126    }
127
128    fn deserialize(
129        _bytes: &[u8],
130        _dtype: &DType,
131        _len: usize,
132        _buffers: &[BufferHandle],
133        _session: &VortexSession,
134    ) -> VortexResult<Self::Metadata> {
135        Ok(EmptyMetadata)
136    }
137
138    fn build(
139        dtype: &DType,
140        len: usize,
141        _metadata: &Self::Metadata,
142        buffers: &[BufferHandle],
143        children: &dyn ArrayChildren,
144    ) -> VortexResult<ByteBoolArray> {
145        let validity = if children.is_empty() {
146            Validity::from(dtype.nullability())
147        } else if children.len() == 1 {
148            let validity = children.get(0, &Validity::DTYPE, len)?;
149            Validity::Array(validity)
150        } else {
151            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
152        };
153
154        if buffers.len() != 1 {
155            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
156        }
157        let buffer = buffers[0].clone();
158
159        Ok(ByteBoolArray::new(buffer, validity))
160    }
161
162    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
163        vortex_ensure!(
164            children.len() <= 1,
165            "ByteBoolArray expects at most 1 child (validity), got {}",
166            children.len()
167        );
168
169        array.validity = if children.is_empty() {
170            Validity::from(array.dtype.nullability())
171        } else {
172            Validity::Array(children.into_iter().next().vortex_expect("checked"))
173        };
174
175        Ok(())
176    }
177
178    fn reduce_parent(
179        array: &Self::Array,
180        parent: &ArrayRef,
181        child_idx: usize,
182    ) -> VortexResult<Option<ArrayRef>> {
183        crate::rules::RULES.evaluate(array, parent, child_idx)
184    }
185
186    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
187        let boolean_buffer = BitBuffer::from(array.as_slice());
188        let validity = array.validity().clone();
189        Ok(ExecutionStep::Done(
190            BoolArray::new(boolean_buffer, validity).into_array(),
191        ))
192    }
193
194    fn execute_parent(
195        array: &Self::Array,
196        parent: &ArrayRef,
197        child_idx: usize,
198        ctx: &mut ExecutionCtx,
199    ) -> VortexResult<Option<ArrayRef>> {
200        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
201    }
202}
203
/// A boolean array that stores each value in its own byte.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    // Logical type; set to `DType::Bool` with the validity's nullability by `new`.
    dtype: DType,
    // Values buffer holding one byte per element.
    buffer: BufferHandle,
    // Validity (null-ness) of the elements.
    validity: Validity,
    // Statistics holder for this array; starts out as `Default::default()`.
    stats_set: ArrayStats,
}
211
/// Unit type on which the `VTable` and `OperationsVTable` implementations for
/// `ByteBoolArray` are defined.
#[derive(Debug)]
pub struct ByteBoolVTable;
214
impl ByteBoolVTable {
    /// Identifier of the byte-bool encoding (`"vortex.bytebool"`), returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
}
218
219impl ByteBoolArray {
220    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
221        let length = buffer.len();
222        if let Some(vlen) = validity.maybe_len()
223            && length != vlen
224        {
225            vortex_panic!(
226                "Buffer length ({}) does not match validity length ({})",
227                length,
228                vlen
229            );
230        }
231        Self {
232            dtype: DType::Bool(validity.nullability()),
233            buffer,
234            validity,
235            stats_set: Default::default(),
236        }
237    }
238
239    // TODO(ngates): deprecate construction from vec
240    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
241        let validity = validity.into();
242        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
243        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
244        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
245    }
246
247    pub fn buffer(&self) -> &BufferHandle {
248        &self.buffer
249    }
250
251    pub fn as_slice(&self) -> &[bool] {
252        // Safety: The internal buffer contains byte-sized bools
253        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
254    }
255}
256
// Exposes the stored validity so that `ValidityVTableFromValidityHelper` (selected as the
// `ValidityVTable` of this encoding) can be used.
impl ValidityHelper for ByteBoolArray {
    /// Returns the array's validity.
    fn validity(&self) -> &Validity {
        &self.validity
    }
}
262
263impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
264    fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
265        Ok(Scalar::bool(
266            array.buffer.as_host()[index] == 1,
267            array.dtype().nullability(),
268        ))
269    }
270}
271
272impl From<Vec<bool>> for ByteBoolArray {
273    fn from(value: Vec<bool>) -> Self {
274        Self::from_vec(value, Validity::AllValid)
275    }
276}
277
278impl From<Vec<Option<bool>>> for ByteBoolArray {
279    fn from(value: Vec<Option<bool>>) -> Self {
280        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
281
282        // This doesn't reallocate, and the compiler even vectorizes it
283        let data = value.into_iter().map(Option::unwrap_or_default).collect();
284
285        Self::from_vec(data, validity)
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the length and per-element validity of arrays built from `Vec<bool>` and
    /// `Vec<Option<bool>>`.
    #[test]
    fn test_validity_construction() {
        // Plain booleans: every element is valid.
        {
            let values = vec![true, false];
            let expected_len = values.len();

            let all_valid = ByteBoolArray::from(values);
            assert_eq!(expected_len, all_valid.len());

            for position in 0..all_valid.len() {
                assert!(all_valid.is_valid(position).unwrap());
            }
        }

        // Mixed options: only the `None` position is invalid.
        {
            let mixed = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
            assert!(mixed.is_valid(0).unwrap());
            assert!(!mixed.is_valid(1).unwrap());
            assert!(mixed.is_valid(2).unwrap());
            assert_eq!(mixed.len(), 3);
        }

        // Only `None`s: every element is invalid.
        {
            let values: Vec<Option<bool>> = vec![None, None];
            let expected_len = values.len();

            let all_null = ByteBoolArray::from(values);
            assert_eq!(expected_len, all_null.len());

            for position in 0..all_null.len() {
                assert!(!all_null.is_valid(position).unwrap());
            }
            assert_eq!(all_null.len(), 2);
        }
    }
}
323}