vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::Range;
7
8use vortex_array::arrays::BoolArray;
9use vortex_array::stats::{ArrayStats, StatsSetRef};
10use vortex_array::validity::Validity;
11use vortex_array::vtable::{
12    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
13    ValidityVTableFromValidityHelper,
14};
15use vortex_array::{
16    ArrayEq, ArrayHash, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, Precision, vtable,
17};
18use vortex_buffer::{BitBuffer, ByteBuffer};
19use vortex_dtype::DType;
20use vortex_error::vortex_panic;
21use vortex_scalar::Scalar;
22
23vtable!(ByteBool);
24
25impl VTable for ByteBoolVTable {
26    type Array = ByteBoolArray;
27    type Encoding = ByteBoolEncoding;
28
29    type ArrayVTable = Self;
30    type CanonicalVTable = Self;
31    type OperationsVTable = Self;
32    type ValidityVTable = ValidityVTableFromValidityHelper;
33    type VisitorVTable = Self;
34    type ComputeVTable = NotSupported;
35    type EncodeVTable = NotSupported;
36    type SerdeVTable = Self;
37    type OperatorVTable = NotSupported;
38
39    fn id(_encoding: &Self::Encoding) -> EncodingId {
40        EncodingId::new_ref("vortex.bytebool")
41    }
42
43    fn encoding(_array: &Self::Array) -> EncodingRef {
44        EncodingRef::new_ref(ByteBoolEncoding.as_ref())
45    }
46}
47
48#[derive(Clone, Debug)]
49pub struct ByteBoolArray {
50    dtype: DType,
51    buffer: ByteBuffer,
52    validity: Validity,
53    stats_set: ArrayStats,
54}
55
56#[derive(Clone, Debug)]
57pub struct ByteBoolEncoding;
58
59impl ByteBoolArray {
60    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
61        let length = buffer.len();
62        if let Some(vlen) = validity.maybe_len()
63            && length != vlen
64        {
65            vortex_panic!(
66                "Buffer length ({}) does not match validity length ({})",
67                length,
68                vlen
69            );
70        }
71        Self {
72            dtype: DType::Bool(validity.nullability()),
73            buffer,
74            validity,
75            stats_set: Default::default(),
76        }
77    }
78
79    // TODO(ngates): deprecate construction from vec
80    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
81        let validity = validity.into();
82        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
83        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
84        Self::new(ByteBuffer::from(data), validity)
85    }
86
87    pub fn buffer(&self) -> &ByteBuffer {
88        &self.buffer
89    }
90
91    pub fn as_slice(&self) -> &[bool] {
92        // Safety: The internal buffer contains byte-sized bools
93        unsafe { std::mem::transmute(self.buffer().as_slice()) }
94    }
95}
96
97impl ValidityHelper for ByteBoolArray {
98    fn validity(&self) -> &Validity {
99        &self.validity
100    }
101}
102
103impl ArrayVTable<ByteBoolVTable> for ByteBoolVTable {
104    fn len(array: &ByteBoolArray) -> usize {
105        array.buffer.len()
106    }
107
108    fn dtype(array: &ByteBoolArray) -> &DType {
109        &array.dtype
110    }
111
112    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
113        array.stats_set.to_ref(array.as_ref())
114    }
115
116    fn array_hash<H: std::hash::Hasher>(
117        array: &ByteBoolArray,
118        state: &mut H,
119        precision: Precision,
120    ) {
121        array.dtype.hash(state);
122        array.buffer.array_hash(state, precision);
123        array.validity.array_hash(state, precision);
124    }
125
126    fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
127        array.dtype == other.dtype
128            && array.buffer.array_eq(&other.buffer, precision)
129            && array.validity.array_eq(&other.validity, precision)
130    }
131}
132
133impl CanonicalVTable<ByteBoolVTable> for ByteBoolVTable {
134    fn canonicalize(array: &ByteBoolArray) -> Canonical {
135        let boolean_buffer = BitBuffer::from(array.as_slice());
136        let validity = array.validity().clone();
137        Canonical::Bool(BoolArray::from_bit_buffer(boolean_buffer, validity))
138    }
139}
140
141impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
142    fn slice(array: &ByteBoolArray, range: Range<usize>) -> ArrayRef {
143        ByteBoolArray::new(
144            array.buffer().slice(range.clone()),
145            array.validity().slice(range),
146        )
147        .into_array()
148    }
149
150    fn scalar_at(array: &ByteBoolArray, index: usize) -> Scalar {
151        Scalar::bool(array.buffer()[index] == 1, array.dtype().nullability())
152    }
153}
154
155impl From<Vec<bool>> for ByteBoolArray {
156    fn from(value: Vec<bool>) -> Self {
157        Self::from_vec(value, Validity::AllValid)
158    }
159}
160
161impl From<Vec<Option<bool>>> for ByteBoolArray {
162    fn from(value: Vec<Option<bool>>) -> Self {
163        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
164
165        // This doesn't reallocate, and the compiler even vectorizes it
166        let data = value.into_iter().map(Option::unwrap_or_default).collect();
167
168        Self::from_vec(data, validity)
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this metadata test is commented out. It references `check_metadata`,
    // `SerdeMetadata`, `ByteBoolMetadata` and `ValidityMetadata`, which this file does not
    // appear to import.
    // #[cfg_attr(miri, ignore)]
    // #[test]
    // fn test_bytebool_metadata() {
    //     check_metadata(
    //         "bytebool.metadata",
    //         SerdeMetadata(ByteBoolMetadata {
    //             validity: ValidityMetadata::AllValid,
    //         }),
    //     );
    // }

    #[test]
    fn test_validity_construction() {
        // Plain `Vec<bool>`: every element reports as valid.
        let plain = vec![true, false];
        let plain_len = plain.len();
        let arr = ByteBoolArray::from(plain);
        assert_eq!(plain_len, arr.len());
        assert!((0..arr.len()).all(|i| arr.is_valid(i)));

        // Mixed `Some` / `None`: validity follows `is_some`.
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert!(arr.is_valid(0));
        assert!(!arr.is_valid(1));
        assert!(arr.is_valid(2));
        assert_eq!(arr.len(), 3);

        // All `None`: every element reports as invalid.
        let all_null: Vec<Option<bool>> = vec![None, None];
        let all_null_len = all_null.len();
        let arr = ByteBoolArray::from(all_null);
        assert_eq!(all_null_len, arr.len());
        assert!((0..arr.len()).all(|i| !arr.is_valid(i)));
        assert_eq!(arr.len(), 2);
    }
}