1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_array::ArrayEq;
9use vortex_array::ArrayHash;
10use vortex_array::ArrayRef;
11use vortex_array::EmptyMetadata;
12use vortex_array::ExecutionCtx;
13use vortex_array::ExecutionResult;
14use vortex_array::IntoArray;
15use vortex_array::Precision;
16use vortex_array::arrays::BoolArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::dtype::DType;
19use vortex_array::scalar::Scalar;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::OperationsVTable;
27use vortex_array::vtable::VTable;
28use vortex_array::vtable::ValidityHelper;
29use vortex_array::vtable::ValidityVTableFromValidityHelper;
30use vortex_array::vtable::validity_nchildren;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBuffer;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexExpect as _;
35use vortex_error::VortexResult;
36use vortex_error::vortex_bail;
37use vortex_error::vortex_ensure;
38use vortex_error::vortex_panic;
39use vortex_session::VortexSession;
40
41use crate::kernel::PARENT_KERNELS;
42
// Invokes the `vtable!` macro for the `ByteBool` encoding.
// NOTE(review): the items this macro generates are defined in `vortex_array` and are not
// visible here; presumably it provides the glue that lets `ByteBoolArray` be used as an array
// (e.g. the `len`/`is_valid`/`as_ref` methods used in this file) — confirm against the macro.
vtable!(ByteBool);
44
45impl VTable for ByteBool {
46 type Array = ByteBoolArray;
47
48 type Metadata = EmptyMetadata;
49 type OperationsVTable = Self;
50 type ValidityVTable = ValidityVTableFromValidityHelper;
51
52 fn vtable(_array: &Self::Array) -> &Self {
53 &ByteBool
54 }
55
56 fn id(&self) -> ArrayId {
57 Self::ID
58 }
59
60 fn len(array: &ByteBoolArray) -> usize {
61 array.buffer.len()
62 }
63
64 fn dtype(array: &ByteBoolArray) -> &DType {
65 &array.dtype
66 }
67
68 fn stats(array: &ByteBoolArray) -> StatsSetRef<'_> {
69 array.stats_set.to_ref(array.as_ref())
70 }
71
72 fn array_hash<H: std::hash::Hasher>(
73 array: &ByteBoolArray,
74 state: &mut H,
75 precision: Precision,
76 ) {
77 array.dtype.hash(state);
78 array.buffer.array_hash(state, precision);
79 array.validity.array_hash(state, precision);
80 }
81
82 fn array_eq(array: &ByteBoolArray, other: &ByteBoolArray, precision: Precision) -> bool {
83 array.dtype == other.dtype
84 && array.buffer.array_eq(&other.buffer, precision)
85 && array.validity.array_eq(&other.validity, precision)
86 }
87
88 fn nbuffers(_array: &ByteBoolArray) -> usize {
89 1
90 }
91
92 fn buffer(array: &ByteBoolArray, idx: usize) -> BufferHandle {
93 match idx {
94 0 => array.buffer().clone(),
95 _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
96 }
97 }
98
99 fn buffer_name(_array: &ByteBoolArray, idx: usize) -> Option<String> {
100 match idx {
101 0 => Some("values".to_string()),
102 _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
103 }
104 }
105
106 fn nchildren(array: &ByteBoolArray) -> usize {
107 validity_nchildren(array.validity())
108 }
109
110 fn child(array: &ByteBoolArray, idx: usize) -> ArrayRef {
111 match idx {
112 0 => validity_to_child(array.validity(), array.len())
113 .vortex_expect("ByteBoolArray validity child out of bounds"),
114 _ => vortex_panic!("ByteBoolArray child index {idx} out of bounds"),
115 }
116 }
117
118 fn child_name(_array: &ByteBoolArray, idx: usize) -> String {
119 match idx {
120 0 => "validity".to_string(),
121 _ => vortex_panic!("ByteBoolArray child_name index {idx} out of bounds"),
122 }
123 }
124
125 fn metadata(_array: &ByteBoolArray) -> VortexResult<Self::Metadata> {
126 Ok(EmptyMetadata)
127 }
128
129 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
130 Ok(Some(vec![]))
131 }
132
133 fn deserialize(
134 _bytes: &[u8],
135 _dtype: &DType,
136 _len: usize,
137 _buffers: &[BufferHandle],
138 _session: &VortexSession,
139 ) -> VortexResult<Self::Metadata> {
140 Ok(EmptyMetadata)
141 }
142
143 fn build(
144 dtype: &DType,
145 len: usize,
146 _metadata: &Self::Metadata,
147 buffers: &[BufferHandle],
148 children: &dyn ArrayChildren,
149 ) -> VortexResult<ByteBoolArray> {
150 let validity = if children.is_empty() {
151 Validity::from(dtype.nullability())
152 } else if children.len() == 1 {
153 let validity = children.get(0, &Validity::DTYPE, len)?;
154 Validity::Array(validity)
155 } else {
156 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
157 };
158
159 if buffers.len() != 1 {
160 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
161 }
162 let buffer = buffers[0].clone();
163
164 Ok(ByteBoolArray::new(buffer, validity))
165 }
166
167 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
168 vortex_ensure!(
169 children.len() <= 1,
170 "ByteBoolArray expects at most 1 child (validity), got {}",
171 children.len()
172 );
173
174 array.validity = if children.is_empty() {
175 Validity::from(array.dtype.nullability())
176 } else {
177 Validity::Array(children.into_iter().next().vortex_expect("checked"))
178 };
179
180 Ok(())
181 }
182
183 fn reduce_parent(
184 array: &Self::Array,
185 parent: &ArrayRef,
186 child_idx: usize,
187 ) -> VortexResult<Option<ArrayRef>> {
188 crate::rules::RULES.evaluate(array, parent, child_idx)
189 }
190
191 fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
192 let boolean_buffer = BitBuffer::from(array.as_slice());
193 let validity = array.validity().clone();
194 Ok(ExecutionResult::done(
195 BoolArray::new(boolean_buffer, validity).into_array(),
196 ))
197 }
198
199 fn execute_parent(
200 array: &Self::Array,
201 parent: &ArrayRef,
202 child_idx: usize,
203 ctx: &mut ExecutionCtx,
204 ) -> VortexResult<Option<ArrayRef>> {
205 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
206 }
207}
208
/// A boolean array stored with one byte per element (its length is the buffer's byte length).
#[derive(Clone, Debug)]
pub struct ByteBoolArray {
    /// `DType::Bool`, with nullability taken from `validity` when built via `new`.
    dtype: DType,
    /// One byte per element. `as_slice` reinterprets these bytes as `bool`, and `scalar_at`
    /// treats a byte equal to `1` as `true`.
    buffer: BufferHandle,
    /// Per-element validity; exposed through the `ValidityHelper` impl.
    validity: Validity,
    /// Statistics storage handed out by `VTable::stats`.
    stats_set: ArrayStats,
}
216
/// Marker type for the byte-bool encoding; it implements `VTable` and `OperationsVTable`
/// for [`ByteBoolArray`].
#[derive(Clone, Debug)]
pub struct ByteBool;

impl ByteBool {
    /// Encoding identifier, returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
}
223
224impl ByteBoolArray {
225 pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
226 let length = buffer.len();
227 if let Some(vlen) = validity.maybe_len()
228 && length != vlen
229 {
230 vortex_panic!(
231 "Buffer length ({}) does not match validity length ({})",
232 length,
233 vlen
234 );
235 }
236 Self {
237 dtype: DType::Bool(validity.nullability()),
238 buffer,
239 validity,
240 stats_set: Default::default(),
241 }
242 }
243
244 pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
246 let validity = validity.into();
247 let data: Vec<u8> = unsafe { std::mem::transmute(data) };
249 Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
250 }
251
252 pub fn buffer(&self) -> &BufferHandle {
253 &self.buffer
254 }
255
256 pub fn as_slice(&self) -> &[bool] {
257 unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
259 }
260}
261
/// Exposes the stored validity. `ByteBool` declares `ValidityVTableFromValidityHelper` as its
/// validity vtable, which presumably derives validity queries from this accessor — confirm in
/// `vortex_array::vtable`.
impl ValidityHelper for ByteBoolArray {
    /// Returns the validity set at construction or last replaced by `with_children`.
    fn validity(&self) -> &Validity {
        &self.validity
    }
}
267
268impl OperationsVTable<ByteBool> for ByteBool {
269 fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
270 Ok(Scalar::bool(
271 array.buffer.as_host()[index] == 1,
272 array.dtype().nullability(),
273 ))
274 }
275}
276
277impl From<Vec<bool>> for ByteBoolArray {
278 fn from(value: Vec<bool>) -> Self {
279 Self::from_vec(value, Validity::AllValid)
280 }
281}
282
283impl From<Vec<Option<bool>>> for ByteBoolArray {
284 fn from(value: Vec<Option<bool>>) -> Self {
285 let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
286
287 let data = value.into_iter().map(Option::unwrap_or_default).collect();
289
290 Self::from_vec(data, validity)
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validity_construction() {
        // From plain bools: every position reports valid.
        let plain = vec![true, false];
        let plain_len = plain.len();
        let array = ByteBoolArray::from(plain);
        assert_eq!(plain_len, array.len());
        (0..array.len()).for_each(|i| assert!(array.is_valid(i).unwrap()));

        // From options: `Some` positions are valid, `None` positions are not.
        let array = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        [true, false, true]
            .into_iter()
            .enumerate()
            .for_each(|(i, expected)| assert_eq!(array.is_valid(i).unwrap(), expected));
        assert_eq!(array.len(), 3);

        // From all-`None` options: no position is valid.
        let all_missing: Vec<Option<bool>> = vec![None; 2];
        let missing_len = all_missing.len();
        let array = ByteBoolArray::from(all_missing);
        assert_eq!(missing_len, array.len());
        (0..array.len()).for_each(|i| assert!(!array.is_valid(i).unwrap()));
        assert_eq!(array.len(), 2);
    }
}