vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Range;
7
8use vortex_array::ArrayBufferVisitor;
9use vortex_array::ArrayChildVisitor;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::Canonical;
14use vortex_array::EmptyMetadata;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::arrays::BoolArray;
18use vortex_array::serde::ArrayChildren;
19use vortex_array::stats::ArrayStats;
20use vortex_array::stats::StatsSetRef;
21use vortex_array::validity::Validity;
22use vortex_array::vtable;
23use vortex_array::vtable::ArrayId;
24use vortex_array::vtable::ArrayVTable;
25use vortex_array::vtable::ArrayVTableExt;
26use vortex_array::vtable::BaseArrayVTable;
27use vortex_array::vtable::CanonicalVTable;
28use vortex_array::vtable::NotSupported;
29use vortex_array::vtable::OperationsVTable;
30use vortex_array::vtable::VTable;
31use vortex_array::vtable::ValidityHelper;
32use vortex_array::vtable::ValidityVTableFromValidityHelper;
33use vortex_array::vtable::VisitorVTable;
34use vortex_buffer::BitBuffer;
35use vortex_buffer::BufferHandle;
36use vortex_buffer::ByteBuffer;
37use vortex_dtype::DType;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_panic;
41use vortex_scalar::Scalar;
42
// Macro from `vortex_array`. It presumably generates the glue that lets `ByteBoolArray` be used
// through the generic array API (e.g. the `as_ref()`, `into_array()`, `len()`, `dtype()` and
// `is_valid()` calls made elsewhere in this file). NOTE(review): confirm against the macro
// definition.
vtable!(ByteBool);
44
45impl VTable for ByteBoolVTable {
46    type Array = ByteBoolArray;
47
48    type Metadata = EmptyMetadata;
49
50    type ArrayVTable = Self;
51    type CanonicalVTable = Self;
52    type OperationsVTable = Self;
53    type ValidityVTable = ValidityVTableFromValidityHelper;
54    type VisitorVTable = Self;
55    type ComputeVTable = NotSupported;
56    type EncodeVTable = NotSupported;
57
58    fn id(&self) -> ArrayId {
59        ArrayId::new_ref("vortex.bytebool")
60    }
61
62    fn encoding(_array: &Self::Array) -> ArrayVTable {
63        ByteBoolVTable.as_vtable()
64    }
65
66    fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
67        Ok(EmptyMetadata)
68    }
69
70    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
71        Ok(Some(vec![]))
72    }
73
74    fn deserialize(_buffer: &[u8]) -> VortexResult<Self::Metadata> {
75        Ok(EmptyMetadata)
76    }
77
78    fn build(
79        &self,
80        dtype: &DType,
81        len: usize,
82        _metadata: &Self::Metadata,
83        buffers: &[BufferHandle],
84        children: &dyn ArrayChildren,
85    ) -> VortexResult<ByteBoolArray> {
86        let validity = if children.is_empty() {
87            Validity::from(dtype.nullability())
88        } else if children.len() == 1 {
89            let validity = children.get(0, &Validity::DTYPE, len)?;
90            Validity::Array(validity)
91        } else {
92            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
93        };
94
95        if buffers.len() != 1 {
96            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
97        }
98        let buffer = buffers[0].clone().try_to_bytes()?;
99
100        Ok(ByteBoolArray::new(buffer, validity))
101    }
102}
103
/// An array of booleans stored one byte per value.
///
/// The dtype is `DType::Bool`, with nullability taken from `validity` (see
/// `ByteBoolArray::new`). The array's length equals `buffer.len()`.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    // Logical type of the array, set to `DType::Bool(..)` by `new`.
    dtype: DType,
    // Element values, one byte each; reinterpreted as `&[bool]` by `as_slice`.
    buffer: ByteBuffer,
    // Per-element null information.
    validity: Validity,
    // Statistics storage handed out through `BaseArrayVTable::stats`.
    stats_set: ArrayStats,
}
111
/// Encoding marker for [`ByteBoolArray`], identified by the id `"vortex.bytebool"`.
///
/// It implements the vtable traits (`VTable`, `BaseArrayVTable`, `CanonicalVTable`,
/// `OperationsVTable`, `VisitorVTable`) that define how the array behaves.
#[derive(Debug)]
pub struct ByteBoolVTable;
114
115impl ByteBoolArray {
116    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
117        let length = buffer.len();
118        if let Some(vlen) = validity.maybe_len()
119            && length != vlen
120        {
121            vortex_panic!(
122                "Buffer length ({}) does not match validity length ({})",
123                length,
124                vlen
125            );
126        }
127        Self {
128            dtype: DType::Bool(validity.nullability()),
129            buffer,
130            validity,
131            stats_set: Default::default(),
132        }
133    }
134
135    // TODO(ngates): deprecate construction from vec
136    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
137        let validity = validity.into();
138        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
139        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
140        Self::new(ByteBuffer::from(data), validity)
141    }
142
143    pub fn buffer(&self) -> &ByteBuffer {
144        &self.buffer
145    }
146
147    pub fn as_slice(&self) -> &[bool] {
148        // Safety: The internal buffer contains byte-sized bools
149        unsafe { std::mem::transmute(self.buffer().as_slice()) }
150    }
151}
152
153impl ValidityHelper for ByteBoolArray {
154    fn validity(&self) -> &Validity {
155        &self.validity
156    }
157}
158
159impl BaseArrayVTable<ByteBoolVTable> for ByteBoolVTable {
160    fn len(array: &ByteBoolArray) -> usize {
161        array.buffer.len()
162    }
163
164    fn dtype(array: &ByteBoolArray) -> &DType {
165        &array.dtype
166    }
167
168    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
169        array.stats_set.to_ref(array.as_ref())
170    }
171
172    fn array_hash<H: std::hash::Hasher>(
173        array: &ByteBoolArray,
174        state: &mut H,
175        precision: Precision,
176    ) {
177        array.dtype.hash(state);
178        array.buffer.array_hash(state, precision);
179        array.validity.array_hash(state, precision);
180    }
181
182    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
183        array.dtype == other.dtype
184            && array.buffer.array_eq(&other.buffer, precision)
185            && array.validity.array_eq(&other.validity, precision)
186    }
187}
188
189impl CanonicalVTable<ByteBoolVTable> for ByteBoolVTable {
190    fn canonicalize(array: &ByteBoolArray) -> Canonical {
191        let boolean_buffer = BitBuffer::from(array.as_slice());
192        let validity = array.validity().clone();
193        Canonical::Bool(BoolArray::from_bit_buffer(boolean_buffer, validity))
194    }
195}
196
197impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
198    fn slice(array: &ByteBoolArray, range: Range<usize>) -> ArrayRef {
199        ByteBoolArray::new(
200            array.buffer().slice(range.clone()),
201            array.validity().slice(range),
202        )
203        .into_array()
204    }
205
206    fn scalar_at(array: &ByteBoolArray, index: usize) -> Scalar {
207        Scalar::bool(array.buffer()[index] == 1, array.dtype().nullability())
208    }
209}
210
211impl VisitorVTable<ByteBoolVTable> for ByteBoolVTable {
212    fn visit_buffers(array: &ByteBoolArray, visitor: &mut dyn ArrayBufferVisitor) {
213        visitor.visit_buffer(array.buffer());
214    }
215
216    fn visit_children(array: &ByteBoolArray, visitor: &mut dyn ArrayChildVisitor) {
217        visitor.visit_validity(array.validity(), array.len());
218    }
219}
220
221impl From<Vec<bool>> for ByteBoolArray {
222    fn from(value: Vec<bool>) -> Self {
223        Self::from_vec(value, Validity::AllValid)
224    }
225}
226
227impl From<Vec<Option<bool>>> for ByteBoolArray {
228    fn from(value: Vec<Option<bool>>) -> Self {
229        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
230
231        // This doesn't reallocate, and the compiler even vectorizes it
232        let data = value.into_iter().map(Option::unwrap_or_default).collect();
233
234        Self::from_vec(data, validity)
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validity_construction() {
        let v = vec![true, false];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx));
        }

        let v = vec![Some(true), None, Some(false)];
        let arr = ByteBoolArray::from(v);
        assert!(arr.is_valid(0));
        assert!(!arr.is_valid(1));
        assert!(arr.is_valid(2));
        assert_eq!(arr.len(), 3);

        let v: Vec<Option<bool>> = vec![None, None];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx));
        }
        assert_eq!(arr.len(), 2);
    }

    /// `from_vec` must store one 0/1 byte per bool and `as_slice` must read them back.
    #[test]
    fn test_from_vec_roundtrip() {
        let arr = ByteBoolArray::from(vec![true, false, true]);
        assert_eq!(arr.as_slice(), &[true, false, true]);
        assert_eq!(arr.buffer().as_slice(), &[1u8, 0, 1]);
    }

    /// Null entries are stored as `false` in the value bytes.
    #[test]
    fn test_nulls_stored_as_false() {
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.as_slice(), &[true, false, false]);
    }
}