1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hasher;
8
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayId;
13use vortex_array::ArrayParts;
14use vortex_array::ArrayRef;
15use vortex_array::ArraySlots;
16use vortex_array::ArrayView;
17use vortex_array::ExecutionCtx;
18use vortex_array::ExecutionResult;
19use vortex_array::IntoArray;
20use vortex_array::Precision;
21use vortex_array::TypedArrayRef;
22use vortex_array::arrays::BoolArray;
23use vortex_array::buffer::BufferHandle;
24use vortex_array::dtype::DType;
25use vortex_array::scalar::Scalar;
26use vortex_array::serde::ArrayChildren;
27use vortex_array::validity::Validity;
28use vortex_array::vtable::OperationsVTable;
29use vortex_array::vtable::VTable;
30use vortex_array::vtable::ValidityVTable;
31use vortex_array::vtable::child_to_validity;
32use vortex_array::vtable::validity_to_child;
33use vortex_buffer::BitBufferMut;
34use vortex_buffer::ByteBuffer;
35use vortex_error::VortexResult;
36use vortex_error::vortex_bail;
37use vortex_error::vortex_ensure;
38use vortex_error::vortex_panic;
39use vortex_session::VortexSession;
40use vortex_session::registry::CachedId;
41
42use crate::kernel::PARENT_KERNELS;
43
/// A typed [`Array`] whose encoding is [`ByteBool`].
pub type ByteBoolArray = Array<ByteBool>;
46
47impl ArrayHash for ByteBoolData {
48 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
49 self.buffer.array_hash(state, precision);
50 }
51}
52
53impl ArrayEq for ByteBoolData {
54 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
55 self.buffer.array_eq(&other.buffer, precision)
56 }
57}
58
59impl VTable for ByteBool {
60 type TypedArrayData = ByteBoolData;
61
62 type OperationsVTable = Self;
63 type ValidityVTable = Self;
64
65 fn id(&self) -> ArrayId {
66 static ID: CachedId = CachedId::new("vortex.bytebool");
67 *ID
68 }
69
70 fn validate(
71 &self,
72 data: &Self::TypedArrayData,
73 dtype: &DType,
74 len: usize,
75 slots: &[Option<ArrayRef>],
76 ) -> VortexResult<()> {
77 let validity = child_to_validity(slots[VALIDITY_SLOT].as_ref(), dtype.nullability());
78 ByteBoolData::validate(data.buffer(), &validity, dtype, len)
79 }
80
81 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
82 1
83 }
84
85 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
86 match idx {
87 0 => array.buffer().clone(),
88 _ => vortex_panic!("ByteBoolArray buffer index {idx} out of bounds"),
89 }
90 }
91
92 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
93 match idx {
94 0 => Some("values".to_string()),
95 _ => vortex_panic!("ByteBoolArray buffer_name index {idx} out of bounds"),
96 }
97 }
98
99 fn serialize(
100 _array: ArrayView<'_, Self>,
101 _session: &VortexSession,
102 ) -> VortexResult<Option<Vec<u8>>> {
103 Ok(Some(vec![]))
104 }
105
106 fn deserialize(
107 &self,
108 dtype: &DType,
109 len: usize,
110 metadata: &[u8],
111 buffers: &[BufferHandle],
112 children: &dyn ArrayChildren,
113 _session: &VortexSession,
114 ) -> VortexResult<ArrayParts<Self>> {
115 if !metadata.is_empty() {
116 vortex_bail!(
117 "ByteBoolArray expects empty metadata, got {} bytes",
118 metadata.len()
119 );
120 }
121 let validity = if children.is_empty() {
122 Validity::from(dtype.nullability())
123 } else if children.len() == 1 {
124 let validity = children.get(0, &Validity::DTYPE, len)?;
125 Validity::Array(validity)
126 } else {
127 vortex_bail!("Expected 0 or 1 child, got {}", children.len());
128 };
129
130 if buffers.len() != 1 {
131 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
132 }
133 let buffer = buffers[0].clone();
134
135 let data = ByteBoolData::new(buffer);
136 let slots = ByteBoolData::make_slots(&validity, len);
137 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
138 }
139
140 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
141 SLOT_NAMES[idx].to_string()
142 }
143
144 fn reduce_parent(
145 array: ArrayView<'_, Self>,
146 parent: &ArrayRef,
147 child_idx: usize,
148 ) -> VortexResult<Option<ArrayRef>> {
149 crate::rules::RULES.evaluate(array, parent, child_idx)
150 }
151
152 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
153 let boolean_buffer = BitBufferMut::from(array.truthy_bytes()).freeze();
155 let validity = array.validity()?;
156 Ok(ExecutionResult::done(
157 BoolArray::new(boolean_buffer, validity).into_array(),
158 ))
159 }
160
161 fn execute_parent(
162 array: ArrayView<'_, Self>,
163 parent: &ArrayRef,
164 child_idx: usize,
165 ctx: &mut ExecutionCtx,
166 ) -> VortexResult<Option<ArrayRef>> {
167 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
168 }
169}
170
/// Position of the validity child within an array's slots.
pub(super) const VALIDITY_SLOT: usize = 0;
/// Number of slots a ByteBool array carries.
pub(super) const NUM_SLOTS: usize = 1;
/// Name of each slot, indexed by slot position.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
175
/// Encoding-specific data of a [`ByteBool`] array: a handle to the backing byte buffer.
///
/// `ByteBoolData::validate` requires the buffer length to equal the array length, so the
/// buffer holds one byte per element.
#[derive(Clone, Debug)]
pub struct ByteBoolData {
    // One byte per boolean element.
    buffer: BufferHandle,
}
180
181impl Display for ByteBoolData {
182 fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
183 Ok(())
184 }
185}
186
187pub trait ByteBoolArrayExt: TypedArrayRef<ByteBool> {
188 fn validity(&self) -> Validity {
189 child_to_validity(
190 self.as_ref().slots()[VALIDITY_SLOT].as_ref(),
191 self.as_ref().dtype().nullability(),
192 )
193 }
194}
195
196impl<T: TypedArrayRef<ByteBool>> ByteBoolArrayExt for T {}
197
/// Marker type for the ByteBool encoding.
///
/// It carries no state; it implements [`VTable`] and provides constructors for
/// [`ByteBoolArray`].
#[derive(Clone, Debug)]
pub struct ByteBool;
200
201impl ByteBool {
202 pub fn new(buffer: BufferHandle, validity: Validity) -> ByteBoolArray {
203 if let Some(len) = validity.maybe_len() {
204 assert_eq!(
205 buffer.len(),
206 len,
207 "ByteBool validity and bytes must have same length"
208 );
209 }
210 let dtype = DType::Bool(validity.nullability());
211
212 let slots = ByteBoolData::make_slots(&validity, buffer.len());
213 let data = ByteBoolData::new(buffer);
214 let len = data.len();
215 unsafe {
216 Array::from_parts_unchecked(
217 ArrayParts::new(ByteBool, dtype, len, data).with_slots(slots),
218 )
219 }
220 }
221
222 pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
224 let validity = validity.into();
225 let bytes: Vec<u8> = data.into_iter().map(|b| b as u8).collect();
227 let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
228 ByteBool::new(handle, validity)
229 }
230
231 pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
233 let validity = Validity::from_iter(data.iter().map(|v| v.is_some()));
234 let bytes: Vec<u8> = data
236 .into_iter()
237 .map(|b| b.unwrap_or_default() as u8)
238 .collect();
239 let handle = BufferHandle::new_host(ByteBuffer::from(bytes));
240 ByteBool::new(handle, validity)
241 }
242}
243
244impl ByteBoolData {
245 pub fn validate(
246 buffer: &BufferHandle,
247 validity: &Validity,
248 dtype: &DType,
249 len: usize,
250 ) -> VortexResult<()> {
251 let expected_dtype = DType::Bool(validity.nullability());
252 vortex_ensure!(
253 dtype == &expected_dtype,
254 "expected dtype {expected_dtype}, got {dtype}"
255 );
256 vortex_ensure!(
257 buffer.len() == len,
258 "expected len {len}, got {}",
259 buffer.len()
260 );
261 if let Some(vlen) = validity.maybe_len() {
262 vortex_ensure!(vlen == len, "expected validity len {len}, got {vlen}");
263 }
264 Ok(())
265 }
266
267 fn make_slots(validity: &Validity, len: usize) -> ArraySlots {
268 vec![validity_to_child(validity, len)].into()
269 }
270
271 pub fn new(buffer: BufferHandle) -> Self {
272 Self { buffer }
273 }
274
275 pub fn len(&self) -> usize {
277 self.buffer.len()
278 }
279
280 pub fn is_empty(&self) -> bool {
282 self.buffer.len() == 0
283 }
284
285 pub fn buffer(&self) -> &BufferHandle {
286 &self.buffer
287 }
288
289 pub fn truthy_bytes(&self) -> &[u8] {
293 self.buffer().as_host().as_slice()
294 }
295}
296
297impl ValidityVTable<ByteBool> for ByteBool {
298 fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
299 Ok(ByteBoolArrayExt::validity(&array))
300 }
301}
302
303impl OperationsVTable<ByteBool> for ByteBool {
304 fn scalar_at(
305 array: ArrayView<'_, ByteBool>,
306 index: usize,
307 _ctx: &mut ExecutionCtx,
308 ) -> VortexResult<Scalar> {
309 Ok(Scalar::bool(
310 array.buffer.as_host()[index] == 1,
311 array.dtype().nullability(),
312 ))
313 }
314}
315
#[cfg(test)]
mod tests {
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::assert_arrays_eq;
    use vortex_array::serde::SerializeOptions;
    use vortex_array::serde::SerializedArray;
    use vortex_array::session::ArraySession;
    use vortex_array::session::ArraySessionExt;
    use vortex_buffer::ByteBufferMut;
    use vortex_session::VortexSession;
    use vortex_session::registry::ReadContext;

    use super::*;

    /// Arrays built by both constructors report the expected length and per-element
    /// validity.
    #[test]
    fn test_validity_construction() {
        // Explicitly all-valid input.
        let bools = vec![true, false];
        let expected_len = bools.len();

        let all_valid = ByteBool::from_vec(bools, Validity::AllValid);
        assert_eq!(expected_len, all_valid.len());

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for i in 0..all_valid.len() {
            assert!(all_valid.is_valid(i, &mut ctx).unwrap());
        }

        // Mixed input: only the `None` position is invalid.
        let mixed = ByteBool::from_option_vec(vec![Some(true), None, Some(false)]);
        assert!(mixed.is_valid(0, &mut ctx).unwrap());
        assert!(!mixed.is_valid(1, &mut ctx).unwrap());
        assert!(mixed.is_valid(2, &mut ctx).unwrap());
        assert_eq!(mixed.len(), 3);

        // All-null input: every position is invalid.
        let nulls: Vec<Option<bool>> = vec![None, None];
        let expected_len = nulls.len();

        let all_null = ByteBool::from_option_vec(nulls);
        assert_eq!(expected_len, all_null.len());

        for i in 0..all_null.len() {
            assert!(!all_null.is_valid(i, &mut ctx).unwrap());
        }
        assert_eq!(all_null.len(), 2);
    }

    /// A nullable array survives a serialize/decode round trip unchanged.
    #[test]
    fn test_nullable_bytebool_serde_roundtrip() {
        let original = ByteBool::from_option_vec(vec![Some(true), None, Some(false), None]);
        let dtype = original.dtype().clone();
        let len = original.len();

        // The encoding must be registered so that decoding can find it.
        let session = VortexSession::empty().with::<ArraySession>();
        session.arrays().register(ByteBool);

        let array_ctx = ArrayContext::empty();
        let chunks = original
            .clone()
            .into_array()
            .serialize(&array_ctx, &session, &SerializeOptions::default())
            .unwrap();

        // Flatten the serialized chunks into one contiguous buffer.
        let mut bytes = ByteBufferMut::empty();
        for chunk in chunks {
            bytes.extend_from_slice(chunk.as_ref());
        }

        let serialized = SerializedArray::try_from(bytes.freeze()).unwrap();
        let read_ctx = ReadContext::new(array_ctx.to_ids());
        let decoded = serialized.decode(&dtype, len, &read_ctx, &session).unwrap();

        assert_arrays_eq!(decoded, original);
    }
}