1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArrayView;
16use vortex_array::ExecutionCtx;
17use vortex_array::ExecutionResult;
18use vortex_array::IntoArray;
19use vortex_array::Precision;
20use vortex_array::TypedArrayRef;
21use vortex_array::arrays::BoolArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::scalar::Scalar;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityVTable;
30use vortex_array::vtable::child_to_validity;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBuffer;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_mask::Mask;
39use vortex_session::VortexSession;
40
41use crate::kernel::PARENT_KERNELS;
42
/// Array type of the [`ByteBool`] encoding.
pub type ByteBoolArray = Array<ByteBool>;
45
46impl ArrayHash for ByteBoolData {
47 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
48 self.buffer.array_hash(state, precision);
49 }
50}
51
52impl ArrayEq for ByteBoolData {
53 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
54 self.buffer.array_eq(&other.buffer, precision)
55 }
56}
57
58impl VTable for ByteBool {
59 type ArrayData = ByteBoolData;
60
61 type OperationsVTable = Self;
62 type ValidityVTable = Self;
63
64 fn id(&self) -> ArrayId {
65 Self::ID
66 }
67
68 fn validate(
69 &self,
70 data: &Self::ArrayData,
71 dtype: &DType,
72 len: usize,
73 slots: &[Option<ArrayRef>],
74 ) -> VortexResult<()> {
75 let validity = child_to_validity(&slots[VALIDITY_SLOT], dtype.nullability());
76 ByteBoolData::validate(data.buffer(), &validity, dtype, len)
77 }
78
79 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
80 1
81 }
82
83 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
84 match idx {
85 0 => array.buffer().clone(),
86 _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
87 }
88 }
89
90 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
91 match idx {
92 0 => Some("values".to_string()),
93 _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
94 }
95 }
96
97 fn serialize(
98 _array: ArrayView<'_, Self>,
99 _session: &VortexSession,
100 ) -> VortexResult<Option<Vec<u8>>> {
101 Ok(Some(vec![]))
102 }
103
104 fn deserialize(
105 &self,
106 dtype: &DType,
107 len: usize,
108 metadata: &[u8],
109 buffers: &[BufferHandle],
110 children: &dyn ArrayChildren,
111 _session: &VortexSession,
112 ) -> VortexResult<ArrayParts<Self>> {
113 if !metadata.is_empty() {
114 vortex_bail!(
115 "ByteBoolArray expects empty metadata, got {} bytes",
116 metadata.len()
117 );
118 }
119 let validity = if children.is_empty() {
120 Validity::from(dtype.nullability())
121 } else if children.len() == 1 {
122 let validity = children.get(0, &Validity::DTYPE, len)?;
123 Validity::Array(validity)
124 } else {
125 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
126 };
127
128 if buffers.len() != 1 {
129 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
130 }
131 let buffer = buffers[0].clone();
132
133 let data = ByteBoolData::new(buffer, validity.clone());
134 let slots = ByteBoolData::make_slots(&validity, len);
135 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
136 }
137
138 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
139 SLOT_NAMES[idx].to_string()
140 }
141
142 fn reduce_parent(
143 array: ArrayView<'_, Self>,
144 parent: &ArrayRef,
145 child_idx: usize,
146 ) -> VortexResult<Option<ArrayRef>> {
147 crate::rules::RULES.evaluate(array, parent, child_idx)
148 }
149
150 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
151 let boolean_buffer = BitBuffer::from(array.as_slice());
152 let validity = array.validity()?;
153 Ok(ExecutionResult::done(
154 BoolArray::new(boolean_buffer, validity).into_array(),
155 ))
156 }
157
158 fn execute_parent(
159 array: ArrayView<'_, Self>,
160 parent: &ArrayRef,
161 child_idx: usize,
162 ctx: &mut ExecutionCtx,
163 ) -> VortexResult<Option<ArrayRef>> {
164 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
165 }
166}
167
/// Index of the slot that holds the optional validity child.
pub(super) const VALIDITY_SLOT: usize = 0;
/// Total number of slots of a `ByteBool` array.
pub(super) const NUM_SLOTS: usize = 1;
/// Slot names, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
172
/// Encoding-specific data of a `ByteBool` array: the value buffer.
///
/// Validity is not stored here; it lives in the array's validity slot.
#[derive(Clone, Debug)]
pub struct ByteBoolData {
    // Handle to the value buffer; its length is reported as the data length.
    buffer: BufferHandle,
}
177
178impl Display for ByteBoolData {
179 fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
180 Ok(())
181 }
182}
183
184pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
185 fn validity(&self) -> Validity {
186 child_to_validity(
187 &self.as_ref().slots()[VALIDITY_SLOT],
188 self.as_ref().dtype().nullability(),
189 )
190 }
191
192 fn validity_mask(&self) -> Mask {
193 self.validity().to_mask(self.as_ref().len())
194 }
195}
196
// Blanket impl: every `TypedArrayRef<ByteBool>` gets the extension methods for free.
impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
198
/// Marker type for the `ByteBool` encoding; it carries the encoding's vtable impls and
/// the array constructors.
#[derive(Clone, Debug)]
pub struct ByteBool;
201
202impl ByteBool {
203 pub const ID: ArrayId = ArrayId::new_ref("vortex.bytebool");
204
205 pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
206 let dtype = DType::Bool(validity.nullability());
207 let slots = ByteBoolData::make_slots(&validity, buffer.len());
208 let data = ByteBoolData::new(buffer, validity);
209 let len = data.len();
210 unsafe {
211 Array::from_parts_unchecked(
212 ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
213 )
214 }
215 }
216
217 pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
219 let validity = validity.into();
220 let data = ByteBoolData::from_vec(data, validity.clone());
221 let dtype = DType::Bool(validity.nullability());
222 let len = data.len();
223 let slots = ByteBoolData::make_slots(&validity, len);
224 unsafe {
225 Array::from_parts_unchecked(
226 ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
227 )
228 }
229 }
230
231 pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
233 let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
234 let data = ByteBoolData::from(data);
235 let dtype = DType::Bool(validity.nullability());
236 let len = data.len();
237 let slots = ByteBoolData::make_slots(&validity, len);
238 unsafe {
239 Array::from_parts_unchecked(
240 ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
241 )
242 }
243 }
244}
245
246impl ByteBoolData {
247 pub fn validate(
248 buffer: &BufferHandle,
249 validity: &Validity,
250 dtype: &DType,
251 len: usize,
252 ) -> VortexResult<()> {
253 let expected_dtype = DType::Bool(validity.nullability());
254 vortex_ensure!(
255 dtype == &expected_dtype,
256 "expected dtype {expected_dtype}, got {dtype}"
257 );
258 vortex_ensure!(
259 buffer.len() == len,
260 "expected len {len}, got {}",
261 buffer.len()
262 );
263 if let Some(vlen) = validity.maybe_len() {
264 vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
265 }
266 Ok(())
267 }
268
269 fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
270 vec![validity_to_child(validity, len)]
271 }
272
273 pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
274 let length = buffer.len();
275 if let Some(vlen) = validity.maybe_len()
276 && length != vlen
277 {
278 vortex_panic!(
279 "Buffer length ({}) does not match validity length ({})",
280 length,
281 vlen
282 );
283 }
284 Self { buffer }
285 }
286
287 pub fn len(&self) -> usize {
289 self.buffer.len()
290 }
291
292 pub fn is_empty(&self) -> bool {
294 self.buffer.len() == 0
295 }
296
297 pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> Self {
299 let validity = validity.into();
300 let data: Vec<u8> = unsafe { std::mem::transmute(data) };
302 Self::new(BufferHandle::new_host(ByteBuffer::from(data)), validity)
303 }
304
305 pub fn buffer(&self) -> &BufferHandle {
306 &self.buffer
307 }
308
309 pub fn as_slice(&self) -> &[bool] {
310 unsafe { std::mem::transmute(self.buffer().as_host().as_slice()) }
312 }
313}
314
315impl ValidityVTable<ByteBool> for ByteBool {
316 fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
317 Ok(ByteBoolArrayExt::validity(&array))
318 }
319}
320
321impl OperationsVTable<ByteBool> for ByteBool {
322 fn scalar_at(
323 array: ArrayView<'_, ByteBool>,
324 index: usize,
325 _ctx: &mut ExecutionCtx,
326 ) -> VortexResult<Scalar> {
327 Ok(Scalar::bool(
328 array.buffer.as_host()[index] == 1,
329 array.dtype().nullability(),
330 ))
331 }
332}
333
334impl From<Vec<bool>> for ByteBoolData {
335 fn from(value: Vec<bool>) -> Self {
336 Self::from_vec(value, Validity::AllValid)
337 }
338}
339
340impl From<Vec<Option<bool>>> for ByteBoolData {
341 fn from(value: Vec<Option<bool>>) -> Self {
342 let validity = Validity::from_iter(value.iter().map(|v| v.is_some()));
343
344 let data = value.into_iter().map(Option::unwrap_or_default).collect();
346
347 Self::from_vec(data, validity)
348 }
349}
350
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    /// Checks length and per-element validity for all-valid, mixed, and all-null inputs.
    #[test]
    fn test_validity_construction() {
        // All-valid input built from plain bools.
        let bools = vec![true, false];
        let expected_len = bools.len();

        let arr = ByteBool::from_vec(bools, Validity::AllValid);
        assert_eq!(expected_len, arr.len());

        for idx in 0..arr.len() {
            assert!(arr.is_valid(idx).unwrap());
        }

        // Mixed input: only the `None` position is invalid.
        let mixed = vec![Some(true), None, Some(false)];
        let arr = ByteBool::from_option_vec(mixed);
        assert!(arr.is_valid(0).unwrap());
        assert!(!arr.is_valid(1).unwrap());
        assert!(arr.is_valid(2).unwrap());
        assert_eq!(arr.len(), 3);

        // All-null input: every position is invalid.
        let nulls: Vec<Option<bool>> = vec![None, None];
        let expected_len = nulls.len();

        let arr = ByteBool::from_option_vec(nulls);
        assert_eq!(expected_len, arr.len());

        for idx in 0..arr.len() {
            assert!(!arr.is_valid(idx).unwrap());
        }
        assert_eq!(arr.len(), 2);
    }

    /// Serializes a nullable array, decodes it again, and expects an equal array.
    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let array = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let dtype = array.dtype().clone();
        let len = array.len();

        // Session used for decoding, with the ByteBool encoding registered.
        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let ctx = ArrayContext::empty();
        let segments = array
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        // Concatenate the serialized segments into one contiguous buffer.
        let mut bytes = ByteBufferMut::empty();
        for segment in segments {
            bytes.extend_from_slice(segment.as_ref());
        }

        let parts = SerializedArray::try_from(bytes.freeze()).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = parts.decode(&dtype, len, &read_ctx, &session).unwrap();

        assert_arrays_eq!(decoded, array);
    }
}