vortex_bytebool/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::arrays::BoolArray;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::validity::Validity;
10use vortex_array::vtable::{
11    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
12    ValidityVTableFromValidityHelper,
13};
14use vortex_array::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
15use vortex_buffer::ByteBuffer;
16use vortex_dtype::DType;
17use vortex_error::{VortexResult, vortex_panic};
18use vortex_scalar::Scalar;
19
// Macro-generated vtable plumbing for the ByteBool encoding. `ByteBoolVTable` is used below but
// not declared anywhere in this file, so this invocation presumably declares it — confirm
// against the `vtable!` macro definition in `vortex_array`.
vtable!(ByteBool);
21
22impl VTable for ByteBoolVTable {
23    type Array = ByteBoolArray;
24    type Encoding = ByteBoolEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = ValidityVTableFromValidityHelper;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = NotSupported;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.bytebool")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(ByteBoolEncoding.as_ref())
41    }
42}
43
/// A boolean array stored as one byte per value.
///
/// Each byte is expected to be `0` (false) or `1` (true): `scalar_at` compares bytes against `1`,
/// and `as_slice` reinterprets the bytes as `bool`. `new` does not check this.
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    /// Set by `new` to `DType::Bool` with the nullability of the supplied validity.
    dtype: DType,
    /// One byte per element; its length is the array's length.
    buffer: ByteBuffer,
    /// Per-element validity; when it reports a length, `new` checks it equals `buffer.len()`.
    validity: Validity,
    /// Statistics storage, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
51
/// Unit marker type for the byte-bool encoding (id `vortex.bytebool`).
#[derive(Clone, Debug)]
pub struct ByteBoolEncoding;
54
55impl ByteBoolArray {
56    pub fn new(buffer: ByteBuffer, validity: Validity) -> Self {
57        let length = buffer.len();
58        if let Some(vlen) = validity.maybe_len() {
59            if length != vlen {
60                vortex_panic!(
61                    "Buffer length ({}) does not match validity length ({})",
62                    length,
63                    vlen
64                );
65            }
66        }
67        Self {
68            dtype: DType::Bool(validity.nullability()),
69            buffer,
70            validity,
71            stats_set: Default::default(),
72        }
73    }
74
75    // TODO(ngates): deprecate construction from vec
76    pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
77        let validity = validity.into();
78        // SAFETY: we are transmuting a Vec<bool> into a Vec<u8>
79        let data: Vec<u8> = unsafe { std::mem::transmute(data) };
80        Self::new(ByteBuffer::from(data), validity)
81    }
82
83    pub fn buffer(&self) -> &ByteBuffer {
84        &self.buffer
85    }
86
87    pub fn as_slice(&self) -> &[bool] {
88        // Safety: The internal buffer contains byte-sized bools
89        unsafe { std::mem::transmute(self.buffer().as_slice()) }
90    }
91}
92
93impl ValidityHelper for ByteBoolArray {
94    fn validity(&self) -> &Validity {
95        &self.validity
96    }
97}
98
99impl ArrayVTable<ByteBoolVTable> for ByteBoolVTable {
100    fn len(array: &ByteBoolArray) -> usize {
101        array.buffer.len()
102    }
103
104    fn dtype(array: &ByteBoolArray) -> &DType {
105        &array.dtype
106    }
107
108    fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
109        array.stats_set.to_ref(array.as_ref())
110    }
111}
112
113impl CanonicalVTable<ByteBoolVTable> for ByteBoolVTable {
114    fn canonicalize(array: &ByteBoolArray) -> VortexResult<Canonical> {
115        let boolean_buffer = BooleanBuffer::from(array.as_slice());
116        let validity = array.validity().clone();
117        Ok(Canonical::Bool(BoolArray::new(boolean_buffer, validity)))
118    }
119}
120
121impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
122    fn slice(array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
123        Ok(ByteBoolArray::new(
124            array.buffer().slice(start..stop),
125            array.validity().slice(start, stop)?,
126        )
127        .into_array())
128    }
129
130    fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
131        Ok(Scalar::bool(
132            array.buffer()[index] == 1,
133            array.dtype().nullability(),
134        ))
135    }
136}
137
138impl From<Vec<bool>> for ByteBoolArray {
139    fn from(value: Vec<bool>) -> Self {
140        Self::from_vec(value, Validity::AllValid)
141    }
142}
143
144impl From<Vec<Option<bool>>> for ByteBoolArray {
145    fn from(value: Vec<Option<bool>>) -> Self {
146        let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
147
148        // This doesn't reallocate, and the compiler even vectorizes it
149        let data = value.into_iter().map(Option::unwrap_or_default).collect();
150
151        Self::from_vec(data, validity)
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Validity derived from `Vec<bool>` (all valid) and `Vec<Option<bool>>` (None => null).
    #[test]
    fn test_validity_construction() {
        let v = vec![true, false];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx).unwrap());
        }

        let v = vec![Some(true), None, Some(false)];
        let arr = ByteBoolArray::from(v);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());
        assert_eq!(arr.len(), 3);

        let v: Vec<Option<bool>> = vec![None, None];
        let v_len = v.len();

        let arr = ByteBoolArray::from(v);
        assert_eq!(v_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx).unwrap());
        }
        assert_eq!(arr.len(), 2);
    }

    /// Booleans are stored as 0/1 bytes and read back unchanged through `as_slice`.
    #[test]
    fn test_from_vec_round_trips_values() {
        let values = vec![true, false, true, true, false];
        let arr = ByteBoolArray::from(values.clone());
        assert_eq!(arr.buffer().as_slice(), &[1u8, 0, 1, 1, 0]);
        assert_eq!(arr.as_slice(), values.as_slice());
    }

    /// Null entries are stored as `false` bytes.
    #[test]
    fn test_nulls_are_stored_as_false() {
        let arr = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_eq!(arr.as_slice(), &[true, false, false]);
    }
}