Skip to main content

vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_array::ArrayEq;
9use vortex_array::ArrayHash;
10use vortex_array::ArrayRef;
11use vortex_array::EmptyMetadata;
12use vortex_array::ExecutionCtx;
13use vortex_array::ExecutionResult;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::arrays::BoolArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::dtype::DType;
19use vortex_array::scalar::Scalar;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::OperationsVTable;
27use vortex_array::vtable::VTable;
28use vortex_array::vtable::ValidityHelper;
29use vortex_array::vtable::ValidityVTableFromValidityHelper;
30use vortex_array::vtable::validity_nchildren;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBuffer;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexExpect as _;
35use vortex_error::VortexResult;
36use vortex_error::vortex_bail;
37use vortex_error::vortex_ensure;
38use vortex_error::vortex_panic;
39use vortex_session::VortexSession;
40
41use crate::kernel::PARENT_KERNELS;
42
43vtable!(ByteBool);
44
45impl VTable for ByteBool {
46    type Array = ByteBoolArray;
47
48    type Metadata = EmptyMetadata;
49    type OperationsVTable = Self;
50    type ValidityVTable = ValidityVTableFromValidityHelper;
51
52    fn vtable(_array: &Self::Array) -> &Self {
53        &ByteBool
54    }
55
56    fn id(&self) -> ArrayId {
57        Self::ID
58    }
59
60    fn len(array: &ByteBoolArray) -> usize {
61        array.buffer.len()
62    }
63
64    fn dtype(array: &ByteBoolArray) -> &DType {
65        &array.dtype
66    }
67
68    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
69        array.stats_set.to_ref(array.as_ref())
70    }
71
72    fn array_hash<H: std::hash::Hasher>(
73        array: &ByteBoolArray,
74        state: &mut H,
75        precision: Precision,
76    ) {
77        array.dtype.hash(state);
78        array.buffer.array_hash(state, precision);
79        array.validity.array_hash(state, precision);
80    }
81
82    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
83        array.dtype == other.dtype
84            && array.buffer.array_eq(&other.buffer, precision)
85            && array.validity.array_eq(&other.validity, precision)
86    }
87
88    fn nbuffers(_array: &ByteBoolArray) -> usize {
89        1
90    }
91
92    fn buffer(array: &ByteBoolArray, idx: usize) -> BufferHandle {
93        match idx {
94            0 => array.buffer().clone(),
95            _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
96        }
97    }
98
99    fn buffer_name(_array: &ByteBoolArray, idx: usize) -> Option<String> {
100        match idx {
101            0 => Some("values".to_string()),
102            _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
103        }
104    }
105
106    fn nchildren(array: &ByteBoolArray) -> usize {
107        validity_nchildren(array.validity())
108    }
109
110    fn child(array: &ByteBoolArray, idx: usize) -> ArrayRef {
111        match idx {
112            0 => validity_to_child(array.validity(), array.len())
113                .vortex_expect("ByteBoolArray validity child out of bounds"),
114            _ => vortex_panic!("ByteBoolArray child index {idx} out of bounds"),
115        }
116    }
117
118    fn child_name(_array: &ByteBoolArray, idx: usize) -> String {
119        match idx {
120            0 => "validity".to_string(),
121            _ => vortex_panic!("ByteBoolArray child_name index {idx} out of bounds"),
122        }
123    }
124
125    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
126        Ok(EmptyMetadata)
127    }
128
129    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
130        Ok(Some(vec![]))
131    }
132
133    fn deserialize(
134        _bytes: &[u8],
135        _dtype: &DType,
136        _len: usize,
137        _buffers: &[BufferHandle],
138        _session: &VortexSession,
139    ) -> VortexResult<Self::Metadata> {
140        Ok(EmptyMetadata)
141    }
142
143    fn build(
144        dtype: &DType,
145        len: usize,
146        _metadata: &Self::Metadata,
147        buffers: &[BufferHandle],
148        children: &dyn ArrayChildren,
149    ) -> VortexResult<ByteBoolArray> {
150        let validity = if children.is_empty() {
151            Validity::from(dtype.nullability())
152        } else if children.len() == 1 {
153            let validity = children.get(0, &Validity::DTYPE, len)?;
154            Validity::Array(validity)
155        } else {
156            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
157        };
158
159        if buffers.len() != 1 {
160            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
161        }
162        let buffer = buffers[0].clone();
163
164        Ok(ByteBoolArray::new(buffer, validity))
165    }
166
167    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
168        vortex_ensure!(
169            children.len() <= 1,
170            "ByteBoolArray expects at most 1 child (validity), got {}",
171            children.len()
172        );
173
174        array.validity = if children.is_empty() {
175            Validity::from(array.dtype.nullability())
176        } else {
177            Validity::Array(children.into_iter().next().vortex_expect("checked"))
178        };
179
180        Ok(())
181    }
182
183    fn reduce_parent(
184        array: &Self::Array,
185        parent: &ArrayRef,
186        child_idx: usize,
187    ) -> VortexResult<Option<ArrayRef>> {
188        crate::rules::RULES.evaluate(array, parent, child_idx)
189    }
190
191    fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
192        let boolean_buffer = BitBuffer::from(array.as_slice());
193        let validity = array.validity().clone();
194        Ok(ExecutionResult::done(
195            BoolArray::new(boolean_buffer, validity).into_array(),
196        ))
197    }
198
199    fn execute_parent(
200        array: &Self::Array,
201        parent: &ArrayRef,
202        child_idx: usize,
203        ctx: &mut ExecutionCtx,
204    ) -> VortexResult<Option<ArrayRef>> {
205        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
206    }
207}
208
209#[derive(Clone, Debug)]
210pub struct ByteBoolArray {
211    dtype: DType,
212    buffer: BufferHandle,
213    validity: Validity,
214    stats_set: ArrayStats,
215}
216
/// Marker type on which the vtable for [`ByteBoolArray`] is implemented.
#[derive(Debug, Clone)]
pub struct ByteBool;
219
220impl ByteBool {
221    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
222}
223
224impl ByteBoolArray {
225    pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
226        let length = buffer.len();
227        if let Some(vlen) = validity.maybe_len()
228            && length != vlen
229        {
230            vortex_panic!(
231                "Buffer length ({}) does not match validity length ({})",
232                length,
233                vlen
234            );
235        }
236        Self {
237            dtype: DType::Bool(validity.nullability()),
238            buffer,
239            validity,
240            stats_set: Default::default(),
241        }
242    }
243
244    // TODO(ngates): deprecate construction from vec
245    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
246        let validity = validity.into();
247        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
248        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
249        Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
250    }
251
252    pub fn buffer(&self) -> &BufferHandle {
253        &self.buffer
254    }
255
256    pub fn as_slice(&self) -> &[bool] {
257        // Safety: The internal buffer contains byte-sized bools
258        unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
259    }
260}
261
262impl ValidityHelper for ByteBoolArray {
263    fn validity(&self) -> &Validity {
264        &self.validity
265    }
266}
267
268impl OperationsVTable<ByteBool> for ByteBool {
269    fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
270        Ok(Scalar::bool(
271            array.buffer.as_host()[index] == 1,
272            array.dtype().nullability(),
273        ))
274    }
275}
276
277impl From<Vec<bool>> for ByteBoolArray {
278    fn from(value: Vec<bool>) -> Self {
279        Self::from_vec(value, Validity::AllValid)
280    }
281}
282
283impl From<Vec<Option<bool>>> for ByteBoolArray {
284    fn from(value: Vec<Option<bool>>) -> Self {
285        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
286
287        // This doesn't reallocate, and the compiler even vectorizes it
288        let data = value.into_iter().map(Option::unwrap_or_default).collect();
289
290        Self::from_vec(data, validity)
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks length and per-element validity for the `From<Vec<bool>>` and
    /// `From<Vec<Option<bool>>>` constructors.
    #[test]
    fn test_validity_construction() {
        // Plain bools: every element is valid.
        let plain = vec![true, false];
        let expected_len = plain.len();
        let all_valid = ByteBoolArray::from(plain);
        assert_eq!(expected_len, all_valid.len());
        assert!((0..all_valid.len()).all(|idx| all_valid.is_valid(idx).unwrap()));

        // Mixed options: validity follows `Some` / `None`.
        let mixed = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert!(mixed.is_valid(0).unwrap());
        assert!(!mixed.is_valid(1).unwrap());
        assert!(mixed.is_valid(2).unwrap());
        assert_eq!(mixed.len(), 3);

        // All `None`: every element is null.
        let nulls: Vec<Option<bool>> = vec![None, None];
        let expected_len = nulls.len();
        let all_null = ByteBoolArray::from(nulls);
        assert_eq!(expected_len, all_null.len());
        assert!((0..all_null.len()).all(|idx| !all_null.is_valid(idx).unwrap()));
        assert_eq!(all_null.len(), 2);
    }
}