Skip to main content

vortex_array/arrays/bool/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::hash::Hasher;
6
7use kernel::PARENT_KERNELS;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::ExecutionResult;
19use crate::array::Array;
20use crate::array::ArrayView;
21use crate::array::VTable;
22use crate::arrays::bool::BoolData;
23use crate::arrays::bool::array::SLOT_NAMES;
24use crate::buffer::BufferHandle;
25use crate::dtype::DType;
26use crate::serde::ArrayChildren;
27use crate::validity::Validity;
28mod canonical;
29mod kernel;
30mod operations;
31mod validity;
32
33use crate::Precision;
34use crate::array::ArrayId;
35use crate::arrays::bool::compute::rules::RULES;
36use crate::hash::ArrayEq;
37use crate::hash::ArrayHash;
38
39/// A [`Bool`]-encoded Vortex array.
40pub type BoolArray = Array<Bool>;
41
42#[derive(prost::Message)]
43pub struct BoolMetadata {
44    // The offset in bits must be <8
45    #[prost(uint32, tag = "1")]
46    pub offset: u32,
47}
48
49impl ArrayHash for BoolData {
50    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
51        self.bits.array_hash(state, precision);
52        self.offset.hash(state);
53    }
54}
55
56impl ArrayEq for BoolData {
57    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
58        self.offset == other.offset && self.bits.array_eq(&other.bits, precision)
59    }
60}
61
62impl VTable for Bool {
63    type ArrayData = BoolData;
64
65    type OperationsVTable = Self;
66    type ValidityVTable = Self;
67
68    fn id(&self) -> ArrayId {
69        Self::ID
70    }
71
72    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
73        1
74    }
75
76    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
77        match idx {
78            0 => array.bits.clone(),
79            _ => vortex_panic!("BoolArray buffer index {idx} out of bounds"),
80        }
81    }
82
83    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
84        match idx {
85            0 => Some("bits".to_string()),
86            _ => None,
87        }
88    }
89
90    fn serialize(
91        array: ArrayView<'_, Self>,
92        _session: &VortexSession,
93    ) -> VortexResult<Option<Vec<u8>>> {
94        assert!(array.offset < 8, "Offset must be <8, got {}", array.offset);
95        Ok(Some(
96            BoolMetadata {
97                offset: u32::try_from(array.offset).vortex_expect("checked"),
98            }
99            .encode_to_vec(),
100        ))
101    }
102
103    fn validate(
104        &self,
105        data: &BoolData,
106        dtype: &DType,
107        len: usize,
108        slots: &[Option<ArrayRef>],
109    ) -> VortexResult<()> {
110        let DType::Bool(nullability) = dtype else {
111            vortex_bail!("Expected bool dtype, got {dtype:?}");
112        };
113        vortex_ensure!(
114            data.bits.len() * 8 >= data.offset + len,
115            "BoolArray buffer with offset {} cannot back outer length {} (buffer bits = {})",
116            data.offset,
117            len,
118            data.bits.len() * 8
119        );
120
121        let validity = crate::array::child_to_validity(&slots[0], *nullability);
122        if let Some(validity_len) = validity.maybe_len() {
123            vortex_ensure!(
124                validity_len == len,
125                "BoolArray validity len {} does not match outer length {}",
126                validity_len,
127                len
128            );
129        }
130
131        Ok(())
132    }
133
134    fn deserialize(
135        &self,
136        dtype: &DType,
137        len: usize,
138        metadata: &[u8],
139
140        buffers: &[BufferHandle],
141        children: &dyn ArrayChildren,
142        _session: &VortexSession,
143    ) -> VortexResult<crate::array::ArrayParts<Self>> {
144        let metadata = BoolMetadata::decode(metadata)?;
145        if buffers.len() != 1 {
146            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
147        }
148
149        let validity = if children.is_empty() {
150            Validity::from(dtype.nullability())
151        } else if children.len() == 1 {
152            let validity = children.get(0, &Validity::DTYPE, len)?;
153            Validity::Array(validity)
154        } else {
155            vortex_bail!("Expected 0 or 1 child, got {}", children.len());
156        };
157
158        let buffer = buffers[0].clone();
159        let slots = BoolData::make_slots(&validity, len);
160        let data = BoolData::try_new_from_handle(buffer, metadata.offset as usize, len, validity)?;
161        Ok(crate::array::ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
162    }
163
164    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
165        SLOT_NAMES[idx].to_string()
166    }
167
168    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
169        Ok(ExecutionResult::done(array))
170    }
171
172    fn execute_parent(
173        array: ArrayView<'_, Self>,
174        parent: &ArrayRef,
175        child_idx: usize,
176        ctx: &mut ExecutionCtx,
177    ) -> VortexResult<Option<ArrayRef>> {
178        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
179    }
180
181    fn reduce_parent(
182        array: ArrayView<'_, Self>,
183        parent: &ArrayRef,
184        child_idx: usize,
185    ) -> VortexResult<Option<ArrayRef>> {
186        RULES.evaluate(array, parent, child_idx)
187    }
188}
189
/// Marker type selecting the bool encoding; its [`VTable`] implementation drives how
/// [`BoolArray`]s are serialized, validated and executed.
#[derive(Clone, Debug)]
pub struct Bool;
192
193impl Bool {
194    pub const ID: ArrayId = ArrayId::new_ref("vortex.bool");
195}
196
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBufferMut;
    use vortex_session::registry::ReadContext;

    use crate::ArrayContext;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::serde::SerializeOptions;
    use crate::serde::SerializedArray;

    /// A nullable bool array must survive serialize -> flatten -> decode unchanged,
    /// including the positions of its nulls.
    #[test]
    fn test_nullable_bool_serde_roundtrip() {
        let original = BoolArray::from_iter([Some(true), None, Some(false), None]);
        let ctx = ArrayContext::empty();

        // Serialize into a sequence of buffers.
        let chunks = original
            .clone()
            .into_array()
            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
            .unwrap();

        // Flatten the buffers into one contiguous byte buffer.
        let mut flat = ByteBufferMut::empty();
        for chunk in chunks {
            flat.extend_from_slice(chunk.as_ref());
        }
        let parsed = SerializedArray::try_from(flat.freeze()).unwrap();

        // The read context is built only after serialization, once `ctx` holds the ids it used.
        let read_ctx = ReadContext::new(ctx.to_ids());
        let roundtripped = parsed
            .decode(original.dtype(), original.len(), &read_ctx, &LEGACY_SESSION)
            .unwrap();

        assert_arrays_eq!(roundtripped, original);
    }
}