1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArrayView;
16use vortex_array::ExecutionCtx;
17use vortex_array::ExecutionResult;
18use vortex_array::IntoArray;
19use vortex_array::Precision;
20use vortex_array::TypedArrayRef;
21use vortex_array::arrays::BoolArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::scalar::Scalar;
25use vortex_array::serde::ArrayChildren;
26use vortex_array::validity::Validity;
27use vortex_array::vtable::OperationsVTable;
28use vortex_array::vtable::VTable;
29use vortex_array::vtable::ValidityVTable;
30use vortex_array::vtable::child_to_validity;
31use vortex_array::vtable::validity_to_child;
32use vortex_buffer::BitBufferMut;
33use vortex_buffer::ByteBuffer;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_ensure;
37use vortex_error::vortex_panic;
38use vortex_session::VortexSession;
39use vortex_session::registry::CachedId;
40
41use crate::kernel::PARENT_KERNELS;
42
/// A typed array whose encoding is [`ByteBool`]: boolean values stored one byte per element.
pub type ByteBoolArray = Array<ByteBool>;
45
46impl ArrayHash for ByteBoolData {
47 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
48 self.buffer.array_hash(state, precision);
49 }
50}
51
52impl ArrayEq for ByteBoolData {
53 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
54 self.buffer.array_eq(&other.buffer, precision)
55 }
56}
57
58impl VTable for ByteBool {
59 type ArrayData = ByteBoolData;
60
61 type OperationsVTable = Self;
62 type ValidityVTable = Self;
63
64 fn id(&self) -> ArrayId {
65 static ID: CachedId = CachedId::new("vortex.bytebool");
66 *ID
67 }
68
69 fn validate(
70 &self,
71 data: &Self::ArrayData,
72 dtype: &DType,
73 len: usize,
74 slots: &[Option<ArrayRef>],
75 ) -> VortexResult<()> {
76 let validity = child_to_validity(slots[VALIDITY_SLOT].as_ref(), dtype.nullability());
77 ByteBoolData::validate(data.buffer(), &validity, dtype, len)
78 }
79
80 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
81 1
82 }
83
84 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
85 match idx {
86 0 => array.buffer().clone(),
87 _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
88 }
89 }
90
91 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
92 match idx {
93 0 => Some("values".to_string()),
94 _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
95 }
96 }
97
98 fn serialize(
99 _array: ArrayView<'_, Self>,
100 _session: &VortexSession,
101 ) -> VortexResult<Option<Vec<u8>>> {
102 Ok(Some(vec![]))
103 }
104
105 fn deserialize(
106 &self,
107 dtype: &DType,
108 len: usize,
109 metadata: &[u8],
110 buffers: &[BufferHandle],
111 children: &dyn ArrayChildren,
112 _session: &VortexSession,
113 ) -> VortexResult<ArrayParts<Self>> {
114 if !metadata.is_empty() {
115 vortex_bail!(
116 "ByteBoolArray expects empty metadata, got {} bytes",
117 metadata.len()
118 );
119 }
120 let validity = if children.is_empty() {
121 Validity::from(dtype.nullability())
122 } else if children.len() == 1 {
123 let validity = children.get(0, &Validity::DTYPE, len)?;
124 Validity::Array(validity)
125 } else {
126 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
127 };
128
129 if buffers.len() != 1 {
130 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
131 }
132 let buffer = buffers[0].clone();
133
134 let data = ByteBoolData::new(buffer);
135 let slots = ByteBoolData::make_slots(&validity, len);
136 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
137 }
138
139 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
140 SLOT_NAMES[idx].to_string()
141 }
142
143 fn reduce_parent(
144 array: ArrayView<'_, Self>,
145 parent: &ArrayRef,
146 child_idx: usize,
147 ) -> VortexResult<Option<ArrayRef>> {
148 crate::rules::RULES.evaluate(array, parent, child_idx)
149 }
150
151 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
152 let boolean_buffer = BitBufferMut::from(array.truthy_bytes()).freeze();
154 let validity = array.validity()?;
155 Ok(ExecutionResult::done(
156 BoolArray::new(boolean_buffer, validity).into_array(),
157 ))
158 }
159
160 fn execute_parent(
161 array: ArrayView<'_, Self>,
162 parent: &ArrayRef,
163 child_idx: usize,
164 ctx: &mut ExecutionCtx,
165 ) -> VortexResult<Option<ArrayRef>> {
166 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
167 }
168}
169
/// Position of the validity child within a ByteBool array's slots.
pub(super) const VALIDITY_SLOT: usize = 0;
/// Number of slots a ByteBool array carries (only the validity slot).
pub(super) const NUM_SLOTS: usize = 1;
/// Display names for each slot, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
174
/// Encoding-specific data for a [`ByteBool`] array: the value bytes, one byte per element.
///
/// Validity is not stored here; it is kept in the array's validity slot.
#[derive(Clone, Debug)]
pub struct ByteBoolData {
    /// Value bytes; `ByteBoolData::validate` requires its length to equal the array length.
    buffer: BufferHandle,
}
179
180impl Display for ByteBoolData {
181 fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
182 Ok(())
183 }
184}
185
186pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
187 fn validity(&self) -> Validity {
188 child_to_validity(
189 self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
190 self.as_ref().dtype().nullability(),
191 )
192 }
193}
194
// Blanket impl: every `TypedArrayRef<ByteBool>` gets the `ByteBoolArrayExt` accessors.
impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
196
/// Marker type for the `vortex.bytebool` encoding, which stores booleans one byte per value.
#[derive(Clone, Debug)]
pub struct ByteBool;
199
200impl ByteBool {
201 pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
202 if let Some(len) = validity.maybe_len() {
203 assert_eq!(
204 buffer.len(),
205 len,
206 "ByteBool validity and bytes must have same length"
207 );
208 }
209 let dtype = DType::Bool(validity.nullability());
210
211 let slots = ByteBoolData::make_slots(&validity, buffer.len());
212 let data = ByteBoolData::new(buffer);
213 let len = data.len();
214 unsafe {
215 Array::from_parts_unchecked(
216 ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
217 )
218 }
219 }
220
221 pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
223 let validity = validity.into();
224 let bytes: Vec<u8> = data.into_iter().map(|b| b as u8).collect();
226 let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
227 ByteBool::new(handle, validity)
228 }
229
230 pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
232 let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
233 let bytes: Vec<u8> = data
235 .into_iter()
236 .map(|b| b.unwrap_or_default() as u8)
237 .collect();
238 let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
239 ByteBool::new(handle, validity)
240 }
241}
242
243impl ByteBoolData {
244 pub fn validate(
245 buffer: &BufferHandle,
246 validity: &Validity,
247 dtype: &DType,
248 len: usize,
249 ) -> VortexResult<()> {
250 let expected_dtype = DType::Bool(validity.nullability());
251 vortex_ensure!(
252 dtype == &expected_dtype,
253 "expected dtype {expected_dtype}, got {dtype}"
254 );
255 vortex_ensure!(
256 buffer.len() == len,
257 "expected len {len}, got {}",
258 buffer.len()
259 );
260 if let Some(vlen) = validity.maybe_len() {
261 vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
262 }
263 Ok(())
264 }
265
266 fn make_slots(validity: &Validity, len: usize) -> Vec<Option<ArrayRef>> {
267 vec![validity_to_child(validity, len)]
268 }
269
270 pub fn new(buffer: BufferHandle) -> Self {
271 Self { buffer }
272 }
273
274 pub fn len(&self) -> usize {
276 self.buffer.len()
277 }
278
279 pub fn is_empty(&self) -> bool {
281 self.buffer.len() == 0
282 }
283
284 pub fn buffer(&self) -> &BufferHandle {
285 &self.buffer
286 }
287
288 pub fn truthy_bytes(&self) -> &[u8] {
292 self.buffer().as_host().as_slice()
293 }
294}
295
296impl ValidityVTable<ByteBool> for ByteBool {
297 fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
298 Ok(ByteBoolArrayExt::validity(&array))
299 }
300}
301
302impl OperationsVTable<ByteBool> for ByteBool {
303 fn scalar_at(
304 array: ArrayView<'_, ByteBool>,
305 index: usize,
306 _ctx: &mut ExecutionCtx,
307 ) -> VortexResult<Scalar> {
308 Ok(Scalar::bool(
309 array.buffer.as_host()[index] == 1,
310 array.dtype().nullability(),
311 ))
312 }
313}
314
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    /// Validity produced by each constructor matches what the caller supplied.
    #[test]
    fn test_validity_construction() {
        // `Validity::AllValid` marks every element valid.
        let values = vec![true, false];
        let expected_len = values.len();
        let all_valid = ByteBool::from_vec(values, Validity::AllValid);
        assert_eq!(expected_len, all_valid.len());

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for idx in 0..all_valid.len() {
            assert!(all_valid.is_valid(idx, &mut ctx).unwrap());
        }

        // Optional inputs: `Some` is valid, `None` is invalid.
        let mixed = ByteBool::from_option_vec(vec![Some(true), None, Some(false)]);
        let expected_validity = [true, false, true];
        for (idx, expected) in expected_validity.iter().copied().enumerate() {
            assert_eq!(mixed.is_valid(idx, &mut ctx).unwrap(), expected);
        }
        assert_eq!(mixed.len(), 3);

        // All-`None` input produces an array with no valid entries.
        let nulls: Vec<Option<bool>> = vec![None, None];
        let expected_len = nulls.len();
        let all_null = ByteBool::from_option_vec(nulls);
        assert_eq!(expected_len, all_null.len());
        for idx in 0..all_null.len() {
            assert!(!all_null.is_valid(idx, &mut ctx).unwrap());
        }
        assert_eq!(all_null.len(), 2);
    }

    /// A nullable ByteBool array survives a serialize -> decode round trip unchanged.
    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let original = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let dtype = original.dtype().clone();
        let len = original.len();

        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let ctx = ArrayContext::empty();
        let chunks = original
            .clone()
            .into_array()
            .serialize(&ctx, &session, &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized pieces into one contiguous buffer.
        let mut flat = ByteBufferMut::empty();
        for chunk in chunks {
            flat.extend_from_slice(chunk.as_ref());
        }

        let serialized = SerializedArray::try_from(flat.freeze()).unwrap();
        let read_ctx = ReadContext::new(ctx.to_ids());
        let decoded = serialized
            .decode(&dtype, len, &read_ctx, &session)
            .unwrap();

        assert_arrays_eq!(decoded, original);
    }
}