Skip to main content

vortex_fastlanes/bitpacking/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::hash::Hasher;
6
7use prost::Message;
8use vortex_array::Array;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayId;
12use vortex_array::ArrayParts;
13use vortex_array::ArrayRef;
14use vortex_array::ArrayView;
15use vortex_array::ExecutionCtx;
16use vortex_array::ExecutionResult;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::builders::ArrayBuilder;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::PType;
23use vortex_array::match_each_integer_ptype;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::require_patches;
27use vortex_array::require_validity;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::validity::Validity;
30use vortex_array::vtable::VTable;
31use vortex_array::vtable::child_to_validity;
32use vortex_array::vtable::validity_to_child;
33use vortex_error::VortexExpect;
34use vortex_error::VortexResult;
35use vortex_error::vortex_bail;
36use vortex_error::vortex_err;
37use vortex_error::vortex_panic;
38use vortex_session::VortexSession;
39use vortex_session::registry::CachedId;
40
41use crate::BitPackedArrayExt;
42use crate::BitPackedData;
43use crate::BitPackedDataParts;
44use crate::bitpack_decompress::unpack_array;
45use crate::bitpack_decompress::unpack_into_primitive_builder;
46use crate::bitpacking::array::BitPackedSlots;
47use crate::bitpacking::array::BitPackedSlotsView;
48use crate::bitpacking::vtable::kernels::PARENT_KERNELS;
49use crate::bitpacking::vtable::rules::RULES;
50mod kernels;
51mod operations;
52mod rules;
53mod validity;
54
/// A [`BitPacked`]-encoded Vortex array.
///
/// This is an alias for [`Array`] instantiated with the [`BitPacked`] vtable type, and it is
/// the type returned by [`BitPacked::try_new`] and [`BitPacked::encode`].
pub type BitPackedArray = Array<BitPacked>;
57
/// Protobuf metadata written by `serialize` and decoded again by `deserialize` for a
/// [`BitPacked`] array.
///
/// Both integer fields are stored as `u32` on the wire; `deserialize` narrows them and returns
/// an error when a value does not fit the in-memory type.
#[derive(Clone, prost::Message)]
pub struct BitPackedMetadata {
    /// The array's `bit_width()`, widened to `u32`; narrowed back to `u8` on deserialization.
    #[prost(uint32, tag = "1")]
    pub(crate) bit_width: u32,
    /// The array's `offset()`, widened to `u32`; narrowed back to `u16` on deserialization.
    #[prost(uint32, tag = "2")]
    pub(crate) offset: u32, // must be <1024
    /// Metadata for the array's patches, or `None` when the array carries no patches.
    #[prost(message, optional, tag = "3")]
    pub(crate) patches: Option<PatchesMetadata>,
}
67
68impl ArrayHash for BitPackedData {
69    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
70        self.offset.hash(state);
71        self.bit_width.hash(state);
72        self.packed.array_hash(state, precision);
73        self.patch_offset.hash(state);
74        self.patch_offset_within_chunk.hash(state);
75    }
76}
77
78impl ArrayEq for BitPackedData {
79    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
80        self.offset == other.offset
81            && self.bit_width == other.bit_width
82            && self.packed.array_eq(&other.packed, precision)
83            && self.patch_offset == other.patch_offset
84            && self.patch_offset_within_chunk == other.patch_offset_within_chunk
85    }
86}
87
88impl VTable for BitPacked {
89    type ArrayData = BitPackedData;
90
91    type OperationsVTable = Self;
92    type ValidityVTable = Self;
93
94    fn id(&self) -> ArrayId {
95        static ID: CachedId = CachedId::new("fastlanes.bitpacked");
96        *ID
97    }
98
99    fn validate(
100        &self,
101        data: &Self::ArrayData,
102        dtype: &DType,
103        len: usize,
104        slots: &[Option<ArrayRef>],
105    ) -> VortexResult<()> {
106        let slots = BitPackedSlotsView::from_slots(slots);
107
108        let validity = child_to_validity(&slots.validity_child.cloned(), dtype.nullability());
109        let patches = match (slots.patch_indices, slots.patch_values) {
110            (Some(indices), Some(values)) => {
111                let patch_offset = data
112                    .patch_offset
113                    .vortex_expect("has patch slots but no patch_offset");
114                Some(unsafe {
115                    Patches::new_unchecked(
116                        len,
117                        patch_offset,
118                        indices.clone(),
119                        values.clone(),
120                        slots.patch_chunk_offsets.cloned(),
121                        data.patch_offset_within_chunk,
122                    )
123                })
124            }
125            _ => None,
126        };
127        BitPackedData::validate(
128            &data.packed,
129            dtype.as_ptype(),
130            &validity,
131            patches.as_ref(),
132            data.bit_width,
133            len,
134            data.offset,
135        )
136    }
137
138    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
139        1
140    }
141
142    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
143        match idx {
144            0 => array.packed().clone(),
145            _ => vortex_panic!("BitPackedArray buffer index {idx} out of bounds"),
146        }
147    }
148
149    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
150        match idx {
151            0 => Some("packed".to_string()),
152            _ => None,
153        }
154    }
155
156    fn reduce_parent(
157        array: ArrayView<'_, Self>,
158        parent: &ArrayRef,
159        child_idx: usize,
160    ) -> VortexResult<Option<ArrayRef>> {
161        RULES.evaluate(array, parent, child_idx)
162    }
163
164    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
165        BitPackedSlots::NAMES[idx].to_string()
166    }
167
168    fn serialize(
169        array: ArrayView<'_, Self>,
170        _session: &VortexSession,
171    ) -> VortexResult<Option<Vec<u8>>> {
172        Ok(Some(
173            BitPackedMetadata {
174                bit_width: array.bit_width() as u32,
175                offset: array.offset() as u32,
176                patches: array
177                    .patches()
178                    .map(|p| p.to_metadata(array.len(), array.dtype()))
179                    .transpose()?,
180            }
181            .encode_to_vec(),
182        ))
183    }
184
185    fn deserialize(
186        &self,
187        dtype: &DType,
188        len: usize,
189        metadata: &[u8],
190        buffers: &[BufferHandle],
191        children: &dyn ArrayChildren,
192        _session: &VortexSession,
193    ) -> VortexResult<ArrayParts<Self>> {
194        let metadata = BitPackedMetadata::decode(metadata)?;
195        if buffers.len() != 1 {
196            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
197        }
198        let packed = buffers[0].clone();
199
200        let load_validity = |child_idx: usize| {
201            if children.len() == child_idx {
202                Ok(Validity::from(dtype.nullability()))
203            } else if children.len() == child_idx + 1 {
204                let validity = children.get(child_idx, &Validity::DTYPE, len)?;
205                Ok(Validity::Array(validity))
206            } else {
207                vortex_bail!(
208                    "Expected {} or {} children, got {}",
209                    child_idx,
210                    child_idx + 1,
211                    children.len()
212                );
213            }
214        };
215
216        let validity_idx = match &metadata.patches {
217            None => 0,
218            Some(patches_meta) if patches_meta.chunk_offsets_dtype()?.is_some() => 3,
219            Some(_) => 2,
220        };
221
222        let validity = load_validity(validity_idx)?;
223
224        let patches = metadata
225            .patches
226            .map(|p| {
227                let indices = children.get(0, &p.indices_dtype()?, p.len()?)?;
228                let values = children.get(1, dtype, p.len()?)?;
229                let chunk_offsets = p
230                    .chunk_offsets_dtype()?
231                    .map(|dtype| children.get(2, &dtype, p.chunk_offsets_len() as usize))
232                    .transpose()?;
233
234                Patches::new(len, p.offset()?, indices, values, chunk_offsets)
235            })
236            .transpose()?;
237
238        let slots = {
239            let (pi, pv, pco) = match &patches {
240                Some(p) => (
241                    Some(p.indices().clone()),
242                    Some(p.values().clone()),
243                    p.chunk_offsets().clone(),
244                ),
245                None => (None, None, None),
246            };
247            let validity_slot = validity_to_child(&validity, len);
248            vec![pi, pv, pco, validity_slot]
249        };
250        let data = BitPackedData::try_new(
251            packed,
252            patches,
253            u8::try_from(metadata.bit_width).map_err(|_| {
254                vortex_err!(
255                    "BitPackedMetadata bit_width {} does not fit in u8",
256                    metadata.bit_width
257                )
258            })?,
259            u16::try_from(metadata.offset).map_err(|_| {
260                vortex_err!(
261                    "BitPackedMetadata offset {} does not fit in u16",
262                    metadata.offset
263                )
264            })?,
265        )?;
266        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
267    }
268
269    fn append_to_builder(
270        array: ArrayView<'_, Self>,
271        builder: &mut dyn ArrayBuilder,
272        ctx: &mut ExecutionCtx,
273    ) -> VortexResult<()> {
274        match_each_integer_ptype!(array.dtype().as_ptype(), |T| {
275            unpack_into_primitive_builder::<T>(
276                array,
277                builder
278                    .as_any_mut()
279                    .downcast_mut()
280                    .vortex_expect("bit packed array must canonicalize into a primitive array"),
281                ctx,
282            )
283        })
284    }
285
286    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
287        require_patches!(
288            array,
289            BitPackedSlots::PATCH_INDICES,
290            BitPackedSlots::PATCH_VALUES,
291            BitPackedSlots::PATCH_CHUNK_OFFSETS
292        );
293        require_validity!(array, BitPackedSlots::VALIDITY_CHILD);
294
295        Ok(ExecutionResult::done(
296            unpack_array(array.as_view(), ctx)?.into_array(),
297        ))
298    }
299
300    fn execute_parent(
301        array: ArrayView<'_, Self>,
302        parent: &ArrayRef,
303        child_idx: usize,
304        ctx: &mut ExecutionCtx,
305    ) -> VortexResult<Option<ArrayRef>> {
306        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
307    }
308}
309
/// Vtable type for the `fastlanes.bitpacked` encoding.
///
/// The [`VTable`] implementation and the [`BitPackedArray`] constructors are attached to this
/// unit struct.
#[derive(Clone, Debug)]
pub struct BitPacked;
312
313impl BitPacked {
314    pub fn try_new(
315        packed: BufferHandle,
316        ptype: PType,
317        validity: Validity,
318        patches: Option<Patches>,
319        bit_width: u8,
320        len: usize,
321        offset: u16,
322    ) -> VortexResult<BitPackedArray> {
323        let dtype = DType::Primitive(ptype, validity.nullability());
324        let slots = {
325            let (pi, pv, pco) = match &patches {
326                Some(p) => (
327                    Some(p.indices().clone()),
328                    Some(p.values().clone()),
329                    p.chunk_offsets().clone(),
330                ),
331                None => (None, None, None),
332            };
333            let validity_slot = validity_to_child(&validity, len);
334            vec![pi, pv, pco, validity_slot]
335        };
336        let data = BitPackedData::try_new(packed, patches, bit_width, offset)?;
337        Array::try_from_parts(ArrayParts::new(BitPacked, dtype, len, data).with_slots(slots))
338    }
339
340    pub fn into_parts(array: BitPackedArray) -> BitPackedDataParts {
341        let len = array.len();
342        let patches = array.patches();
343        let validity = array.validity().vortex_expect("BitPacked validity");
344        let data = array.into_data();
345        BitPackedDataParts {
346            offset: data.offset,
347            bit_width: data.bit_width,
348            len,
349            packed: data.packed,
350            patches,
351            validity,
352        }
353    }
354
355    /// Encode an array into a bitpacked representation with the given bit width.
356    pub fn encode(array: &ArrayRef, bit_width: u8) -> VortexResult<BitPackedArray> {
357        BitPackedData::encode(array, bit_width)
358    }
359}