Skip to main content

vortex_alp/alp_rd/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#![allow(clippy::cast_possible_truncation)]
5
6pub use array::*;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::patches::Patches;
10use vortex_array::validity::Validity;
11use vortex_fastlanes::bitpack_compress::bitpack_encode_unchecked;
12
13mod array;
14mod compute;
15mod kernel;
16mod ops;
17mod rules;
18mod slice;
19
20use std::ops::Shl;
21use std::ops::Shr;
22
23use itertools::Itertools;
24use num_traits::Float;
25use num_traits::One;
26use num_traits::PrimInt;
27use rustc_hash::FxBuildHasher;
28use vortex_array::arrays::PrimitiveArray;
29use vortex_array::dtype::DType;
30use vortex_array::dtype::NativePType;
31use vortex_array::match_each_integer_ptype;
32use vortex_buffer::Buffer;
33use vortex_buffer::BufferMut;
34use vortex_error::VortexExpect;
35use vortex_error::VortexResult;
36use vortex_error::vortex_panic;
37use vortex_utils::aliases::hash_map::HashMap;
38
39use crate::match_each_alp_float_ptype;
40
/// Number of bits needed to represent `$value`; zero is defined to occupy 1 bit
/// so downstream bit-packing always receives a non-zero width.
macro_rules! bit_width {
    ($value:expr) => {
        match $value {
            0 => 1,
            nonzero => nonzero.ilog2() as usize + 1,
        }
    };
}
50
/// Max number of bits to cut from the MSB section of each float.
const CUT_LIMIT: usize = 16;

/// Maximum number of entries in the left-parts dictionary; with 8 entries the
/// dictionary codes bit-pack down to at most 3 bits.
const MAX_DICT_SIZE: u8 = 8;
55
mod private {
    /// Sealed-trait marker: keeps [`super::ALPRDFloat`] closed to the `f32`/`f64`
    /// implementations in this module, so downstream crates cannot implement it.
    pub trait Sealed {}

    impl Sealed for f32 {}
    impl Sealed for f64 {}
}
62
/// Main trait for ALP-RD encodable floating point numbers.
///
/// Like the paper, we limit this to the IEEE 754 single-precision (`f32`) and double-precision
/// (`f64`) floating point types.
pub trait ALPRDFloat: private::Sealed + Float + Copy + NativePType {
    /// The unsigned integer type with the same bit-width as the floating-point type.
    type UINT: NativePType + PrimInt + One + Copy;

    /// Number of bits the value occupies in registers.
    const BITS: usize = size_of::<Self>() * 8;

    /// Bit-wise transmute from the unsigned integer type to the floating-point type.
    fn from_bits(bits: Self::UINT) -> Self;

    /// Bit-wise transmute into the unsigned integer type.
    fn to_bits(value: Self) -> Self::UINT;

    /// Truncating conversion from the unsigned integer type to `u16` (keeps the low 16 bits).
    fn to_u16(bits: Self::UINT) -> u16;

    /// Type-widening conversion from `u16` to the unsigned integer type.
    fn from_u16(value: u16) -> Self::UINT;
}
86
87impl ALPRDFloat for f64 {
88    type UINT = u64;
89
90    fn from_bits(bits: Self::UINT) -> Self {
91        f64::from_bits(bits)
92    }
93
94    fn to_bits(value: Self) -> Self::UINT {
95        value.to_bits()
96    }
97
98    fn to_u16(bits: Self::UINT) -> u16 {
99        bits as u16
100    }
101
102    fn from_u16(value: u16) -> Self::UINT {
103        value as u64
104    }
105}
106
107impl ALPRDFloat for f32 {
108    type UINT = u32;
109
110    fn from_bits(bits: Self::UINT) -> Self {
111        f32::from_bits(bits)
112    }
113
114    fn to_bits(value: Self) -> Self::UINT {
115        value.to_bits()
116    }
117
118    fn to_u16(bits: Self::UINT) -> u16 {
119        bits as u16
120    }
121
122    fn from_u16(value: u16) -> Self::UINT {
123        value as u32
124    }
125}
126
/// Encoder for ALP-RD ("real doubles") values.
///
/// The encoder calculates its parameters from a single sample of floating-point values,
/// and then can be applied to many vectors.
///
/// ALP-RD uses the algorithm outlined in Section 3.4 of the paper. The crux of it is that the front
/// (most significant) bits of many double vectors tend to be the same, i.e. most doubles in a
/// vector often use the same exponent and front bits. Compression proceeds by finding the best
/// prefix of up to 16 bits that can be collapsed into a dictionary of
/// up to 8 elements. Each double can then be broken into the front/left `L` bits, which neatly
/// bit-packs down to 1-3 bits per element (depending on the actual dictionary size).
/// The remaining `R` bits naturally bit-pack.
///
/// In the ideal case, this scheme allows us to store a sequence of doubles in 49 bits-per-value.
///
/// Our implementation draws on the MIT-licensed [C++ implementation] provided by the original authors.
///
/// [C++ implementation]: https://github.com/cwida/ALP/blob/main/include/alp/rd.hpp
pub struct RDEncoder {
    /// Bit width of the right (least-significant) part of each float.
    right_bit_width: u8,
    /// Reverse dictionary: `codes[code]` is the left bit pattern assigned to `code`.
    codes: Vec<u16>,
}
149
impl RDEncoder {
    /// Build a new encoder from a sample of doubles.
    ///
    /// Derives the left-parts dictionary and the right-part bit width from `sample`;
    /// the resulting encoder can then be applied to any number of vectors.
    pub fn new<T>(sample: &[T]) -> Self
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        let dictionary = find_best_dictionary::<T>(sample);

        // Invert the (bits -> code) dictionary into a (code -> bits) lookup vector.
        let mut codes = vec![0; dictionary.dictionary.len()];
        dictionary.dictionary.into_iter().for_each(|(bits, code)| {
            // write the reverse mapping into the codes vector.
            codes[code as usize] = bits
        });

        Self {
            right_bit_width: dictionary.right_bit_width,
            codes,
        }
    }

    /// Build a new encoder from known parameters.
    ///
    /// `codes[code]` must be the left bit pattern assigned to `code`.
    pub fn from_parts(right_bit_width: u8, codes: Vec<u16>) -> Self {
        Self {
            right_bit_width,
            codes,
        }
    }

    /// Encode a set of floating point values with ALP-RD.
    ///
    /// Each value will be split into a left and right component, which are compressed individually.
    ///
    /// # Panics
    ///
    /// Panics if this encoder was constructed with an empty `codes` table.
    // TODO(joe): make fallible
    pub fn encode(&self, array: &PrimitiveArray) -> ALPRDArray {
        match_each_alp_float_ptype!(array.ptype(), |P| { self.encode_generic::<P>(array) })
    }

    /// Monomorphic implementation of [`Self::encode`] for a single float type `T`.
    fn encode_generic<T>(&self, array: &PrimitiveArray) -> ALPRDArray
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        assert!(
            !self.codes.is_empty(),
            "codes lookup table must be populated before RD encoding"
        );

        let doubles = array.as_slice::<T>();

        let mut left_parts: BufferMut<u16> = BufferMut::with_capacity(doubles.len());
        let mut right_parts: BufferMut<T::UINT> = BufferMut::with_capacity(doubles.len());
        // len/4 is a capacity heuristic for exceptions; the buffers grow if exceeded.
        let mut exceptions_pos: BufferMut<u64> = BufferMut::with_capacity(doubles.len() / 4);
        let mut exceptions: BufferMut<u16> = BufferMut::with_capacity(doubles.len() / 4);

        // mask for right-parts
        let right_mask = T::UINT::one().shl(self.right_bit_width as _) - T::UINT::one();
        let max_code = self.codes.len() - 1;
        let left_bit_width = bit_width!(max_code);

        // Split each float's bit pattern: the low `right_bit_width` bits go to
        // right_parts, the remaining high bits (at most 16 of them) go to left_parts.
        for v in doubles.iter().copied() {
            right_parts.push(T::to_bits(v) & right_mask);
            left_parts.push(<T as ALPRDFloat>::to_u16(
                T::to_bits(v).shr(self.right_bit_width as _),
            ));
        }

        // dict-encode the left-parts, keeping track of exceptions
        for (idx, left) in left_parts.iter_mut().enumerate() {
            // TODO: revisit if we need to change the branch order for perf.
            if let Some(code) = self.codes.iter().position(|v| *v == *left) {
                *left = code as u16;
            } else {
                exceptions.push(*left);
                exceptions_pos.push(idx as _);

                // Placeholder code 0 keeps the slot within left_bit_width; the real
                // value is restored from the patches on decode.
                *left = 0u16;
            }
        }

        // Bit-pack down the encoded left-parts array that have been dictionary encoded.
        let primitive_left = PrimitiveArray::new(
            left_parts,
            array
                .validity()
                .vortex_expect("ALP RD validity should be derivable"),
        );
        // SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
        let packed_left = unsafe {
            bitpack_encode_unchecked(primitive_left, left_bit_width as _)
                .vortex_expect("bitpack_encode_unchecked should succeed for left parts")
                .into_array()
        };

        let primitive_right = PrimitiveArray::new(right_parts, Validity::NonNullable);
        // SAFETY: by construction, all values in right_parts are right_bit_width + leading zeros.
        let packed_right = unsafe {
            bitpack_encode_unchecked(primitive_right, self.right_bit_width as _)
                .vortex_expect("bitpack_encode_unchecked should succeed for right parts")
                .into_array()
        };

        // Build Patches recording the positions and true left values of the exceptions,
        // or None when every left pattern was found in the dictionary.
        let exceptions = (!exceptions_pos.is_empty()).then(|| {
            // Positions were pushed in ascending index order, so the last is the maximum.
            let max_exc_pos = exceptions_pos.last().copied().unwrap_or_default();
            let bw = bit_width!(max_exc_pos) as u8;

            let exc_pos_array = PrimitiveArray::new(exceptions_pos, Validity::NonNullable);
            // SAFETY: We calculate bw such that it is wide enough to hold the largest position index.
            let packed_pos = unsafe {
                bitpack_encode_unchecked(exc_pos_array, bw)
                    .vortex_expect(
                        "bitpack_encode_unchecked should succeed for exception positions",
                    )
                    .into_array()
            };

            Patches::new(
                doubles.len(),
                0,
                packed_pos,
                exceptions.into_array(),
                // TODO(0ax1): handle chunk offsets
                None,
            )
            .vortex_expect("Patches construction in encode")
        });

        ALPRD::try_new(
            DType::Primitive(T::PTYPE, packed_left.dtype().nullability()),
            packed_left,
            Buffer::<u16>::copy_from(&self.codes),
            packed_right,
            self.right_bit_width,
            exceptions,
        )
        .vortex_expect("ALPRDArray construction in encode")
    }
}
290
/// Decode ALP-RD encoded values back into their original floating point format.
///
/// `left_parts` holds one dictionary code per value, `left_parts_dict` maps each code back to
/// its left bit pattern, and `right_parts` holds the low `right_bit_width` bits of each value.
/// `left_parts_patches`, when present, carries the true left bit patterns for positions whose
/// left parts were dictionary exceptions.
///
/// # Errors
///
/// Returns an error if materializing the patch index or value arrays fails.
///
/// # Panics
///
/// Panics if `left_parts` and `right_parts` differ in length.
pub fn alp_rd_decode<T: ALPRDFloat>(
    mut left_parts: BufferMut<u16>,
    left_parts_dict: &[u16],
    right_bit_width: u8,
    right_parts: BufferMut<T::UINT>,
    left_parts_patches: Option<Patches>,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Buffer<T>> {
    if left_parts.len() != right_parts.len() {
        vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len");
    }

    // Left bits sit above the right bits, so shifting left values up by
    // right_bit_width re-aligns them for the OR below.
    let shift = right_bit_width as usize;

    if let Some(patches) = left_parts_patches {
        // Patched path: some left-part codes map to exception values that live outside
        // the dictionary. We must dictionary-decode first, then overwrite the exceptions,
        // before we can combine with right-parts.

        // Dictionary-decode every code in-place (code → actual left bit-pattern).
        for code in left_parts.iter_mut() {
            *code = left_parts_dict[*code as usize];
        }

        // Overwrite exception positions with their true left bit-patterns.
        let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
        let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
        alp_rd_apply_patches(&mut left_parts, &indices, &patch_values, patches.offset());

        // Reconstruct floats by shifting each decoded left value into the MSBs
        // and OR-ing with the corresponding right value.
        alp_rd_combine_inplace::<T>(
            right_parts,
            |right, &left| {
                *right = (<T as ALPRDFloat>::from_u16(left) << shift) | *right;
            },
            left_parts.as_ref(),
        )
    } else {
        // Non-patched fast path: every code maps through the dictionary, so we can
        // pre-shift the entire dictionary once and reduce the per-element hot loop to
        // a single table lookup + OR.
        let mut shifted_dict = [T::UINT::default(); MAX_DICT_SIZE as usize];
        for (i, &entry) in left_parts_dict.iter().enumerate() {
            shifted_dict[i] = <T as ALPRDFloat>::from_u16(entry) << shift;
        }

        // Each element: look up the pre-shifted left value by code, OR with right-parts.
        alp_rd_combine_inplace::<T>(
            right_parts,
            |right, &code| {
                // SAFETY: codes are bounded by dict size (< left_parts_dict.len() <= MAX_DICT_SIZE).
                *right = unsafe { *shifted_dict.get_unchecked(code as usize) } | *right;
            },
            left_parts.as_ref(),
        )
    }
}
354
355/// Apply patches to the decoded left-parts values.
356fn alp_rd_apply_patches(
357    values: &mut BufferMut<u16>,
358    indices: &PrimitiveArray,
359    patch_values: &PrimitiveArray,
360    offset: usize,
361) {
362    match_each_integer_ptype!(indices.ptype(), |T| {
363        indices
364            .as_slice::<T>()
365            .iter()
366            .copied()
367            .map(|idx| idx - offset as T)
368            .zip(patch_values.as_slice::<u16>().iter())
369            .for_each(|(idx, v)| values[idx as usize] = *v);
370    })
371}
372
373/// Zip `right_parts` with `left_data`, apply `combine_fn` per element, then reinterpret the
374/// buffer from `T::UINT` to `T` (same bit-width: u32↔f32, u64↔f64).
375fn alp_rd_combine_inplace<T: ALPRDFloat>(
376    mut right_parts: BufferMut<T::UINT>,
377    combine_fn: impl Fn(&mut T::UINT, &u16),
378    left_data: &[u16],
379) -> VortexResult<Buffer<T>> {
380    for (right, left) in right_parts.as_mut_slice().iter_mut().zip(left_data.iter()) {
381        combine_fn(right, left);
382    }
383    // SAFETY: all bit patterns of T::UINT are valid T (u32↔f32 or u64↔f64).
384    Ok(unsafe { right_parts.transmute::<T>() }.freeze())
385}
386/// Find the best "cut point" for a set of floating point values such that we can
387/// cast them all to the relevant value instead.
388fn find_best_dictionary<T: ALPRDFloat>(samples: &[T]) -> ALPRDDictionary {
389    let mut best_est_size = f64::MAX;
390    let mut best_dict = ALPRDDictionary::default();
391
392    for p in 1..=16 {
393        let candidate_right_bw = (T::BITS - p) as u8;
394        let (dictionary, exception_count) =
395            build_left_parts_dictionary::<T>(samples, candidate_right_bw, MAX_DICT_SIZE);
396        let estimated_size = estimate_compression_size(
397            dictionary.right_bit_width,
398            dictionary.left_bit_width,
399            exception_count,
400            samples.len(),
401        );
402        if estimated_size < best_est_size {
403            best_est_size = estimated_size;
404            best_dict = dictionary;
405        }
406    }
407
408    best_dict
409}
410
411/// Build dictionary of the leftmost bits.
412fn build_left_parts_dictionary<T: ALPRDFloat>(
413    samples: &[T],
414    right_bw: u8,
415    max_dict_size: u8,
416) -> (ALPRDDictionary, usize) {
417    assert!(
418        right_bw >= (T::BITS - CUT_LIMIT) as _,
419        "left-parts must be <= 16 bits"
420    );
421
422    // Count the number of occurrences of each left bit pattern
423    let mut counts = HashMap::new();
424    samples
425        .iter()
426        .copied()
427        .map(|v| <T as ALPRDFloat>::to_u16(T::to_bits(v).shr(right_bw as _)))
428        .for_each(|item| *counts.entry(item).or_default() += 1);
429
430    // Sorted counts: sort by negative count so that heavy hitters sort first.
431    let mut sorted_bit_counts: Vec<(u16, usize)> = counts.into_iter().collect_vec();
432    sorted_bit_counts.sort_by_key(|(_, count)| count.wrapping_neg());
433
434    // Assign the most-frequently occurring left-bits as dictionary codes, up to `dict_size`...
435    let mut dictionary = HashMap::with_capacity_and_hasher(max_dict_size as _, FxBuildHasher);
436    let mut code = 0u16;
437    while code < (max_dict_size as _) && (code as usize) < sorted_bit_counts.len() {
438        let (bits, _) = sorted_bit_counts[code as usize];
439        dictionary.insert(bits, code);
440        code += 1;
441    }
442
443    // ...and the rest are exceptions.
444    let exception_count: usize = sorted_bit_counts
445        .iter()
446        .skip(code as _)
447        .map(|(_, count)| *count)
448        .sum();
449
450    // Left bit-width is determined based on the actual dictionary size.
451    let max_code = dictionary.len() - 1;
452    let left_bw = bit_width!(max_code) as u8;
453
454    (
455        ALPRDDictionary {
456            dictionary,
457            right_bit_width: right_bw,
458            left_bit_width: left_bw,
459        },
460        exception_count,
461    )
462}
463
/// Estimate the bits-per-value when using these compression settings.
///
/// The estimate is the packed width of one value (left + right parts) plus the
/// amortized per-value cost of storing the exceptions from the sample.
fn estimate_compression_size(
    right_bw: u8,
    left_bw: u8,
    exception_count: usize,
    sample_n: usize,
) -> f64 {
    const EXC_POSITION_SIZE: usize = 16; // two bytes for exception position.
    const EXC_SIZE: usize = 16; // two bytes for each exception (up to 16 front bits).

    let packed_bits = f64::from(right_bw) + f64::from(left_bw);
    let exception_bits = (exception_count * (EXC_POSITION_SIZE + EXC_SIZE)) as f64;
    packed_bits + exception_bits / (sample_n as f64)
}
477
/// The ALP-RD dictionary, encoding the "left parts" and their dictionary encoding.
///
/// `Default` yields an empty dictionary with zero bit widths; `find_best_dictionary`
/// uses it as the initial "best" candidate.
#[derive(Debug, Default)]
struct ALPRDDictionary {
    /// Items in the dictionary are bit patterns, along with their 16-bit encoding.
    dictionary: HashMap<u16, u16, FxBuildHasher>,
    /// The (compressed) left bit width. This is after bit-packing the dictionary codes.
    left_bit_width: u8,
    /// The right bit width. This is the bit-packed width of each of the "real double" values.
    right_bit_width: u8,
}