Skip to main content

vortex_alp/alp_rd/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#![expect(clippy::cast_possible_truncation)]
5
6pub use array::*;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::patches::Patches;
10use vortex_array::validity::Validity;
11use vortex_fastlanes::bitpack_compress::bitpack_encode_unchecked;
12
13mod array;
14mod compute;
15mod kernel;
16mod ops;
17mod rules;
18mod slice;
19
20use std::ops::Shl;
21use std::ops::Shr;
22
23use itertools::Itertools;
24use num_traits::Float;
25use num_traits::One;
26use num_traits::PrimInt;
27use rustc_hash::FxBuildHasher;
28use vortex_array::ArrayView;
29use vortex_array::arrays::Primitive;
30use vortex_array::arrays::PrimitiveArray;
31use vortex_array::dtype::DType;
32use vortex_array::dtype::NativePType;
33use vortex_array::match_each_integer_ptype;
34use vortex_buffer::Buffer;
35use vortex_buffer::BufferMut;
36use vortex_error::VortexExpect;
37use vortex_error::VortexResult;
38use vortex_error::vortex_panic;
39use vortex_utils::aliases::hash_map::HashMap;
40
41use crate::match_each_alp_float_ptype;
42
/// Number of bits needed to represent `$value`: position of the highest set bit plus one.
///
/// Zero is special-cased to 1 because `ilog2` panics on zero and even a zero value
/// still occupies one bit when packed.
macro_rules! bit_width {
    ($value:expr) => {
        if $value != 0 {
            $value.ilog2().wrapping_add(1) as usize
        } else {
            1
        }
    };
}
52
/// Max number of bits to cut from the MSB section of each float.
const CUT_LIMIT: usize = 16;

/// Max number of entries in the left-parts dictionary. With 8 entries, dictionary codes
/// (0..=7) bit-pack into at most 3 bits.
const MAX_DICT_SIZE: u8 = 8;
57
mod private {
    /// Sealing trait: only types in this module's crate can implement it, which prevents
    /// downstream crates from implementing `ALPRDFloat` for anything beyond `f32`/`f64`.
    pub trait Sealed {}

    impl Sealed for f32 {}
    impl Sealed for f64 {}
}
64
/// Main trait for ALP-RD encodable floating point numbers.
///
/// Like the paper, we limit this to the IEEE 754 single-precision (`f32`) and double-precision
/// (`f64`) floating point types.
pub trait ALPRDFloat: private::Sealed + Float + Copy + NativePType {
    /// The unsigned integer type with the same bit-width as the floating-point type.
    type UINT: NativePType + PrimInt + One + Copy;

    /// Number of bits the value occupies in registers.
    const BITS: usize = size_of::<Self>() * 8;

    /// Bit-wise transmute from the unsigned integer type to the floating-point type.
    fn from_bits(bits: Self::UINT) -> Self;

    /// Bit-wise transmute into the unsigned integer type.
    fn to_bits(value: Self) -> Self::UINT;

    /// Truncating conversion from the unsigned integer type to `u16`.
    fn to_u16(bits: Self::UINT) -> u16;

    /// Type-widening conversion from `u16` to the unsigned integer type.
    fn from_u16(value: u16) -> Self::UINT;
}
88
89impl ALPRDFloat for f64 {
90    type UINT = u64;
91
92    fn from_bits(bits: Self::UINT) -> Self {
93        f64::from_bits(bits)
94    }
95
96    fn to_bits(value: Self) -> Self::UINT {
97        value.to_bits()
98    }
99
100    fn to_u16(bits: Self::UINT) -> u16 {
101        bits as u16
102    }
103
104    fn from_u16(value: u16) -> Self::UINT {
105        value as u64
106    }
107}
108
109impl ALPRDFloat for f32 {
110    type UINT = u32;
111
112    fn from_bits(bits: Self::UINT) -> Self {
113        f32::from_bits(bits)
114    }
115
116    fn to_bits(value: Self) -> Self::UINT {
117        value.to_bits()
118    }
119
120    fn to_u16(bits: Self::UINT) -> u16 {
121        bits as u16
122    }
123
124    fn from_u16(value: u16) -> Self::UINT {
125        value as u32
126    }
127}
128
/// Encoder for ALP-RD ("real doubles") values.
///
/// The encoder calculates its parameters from a single sample of floating-point values,
/// and then can be applied to many vectors.
///
/// ALP-RD uses the algorithm outlined in Section 3.4 of the paper. The crux of it is that the front
/// (most significant) bits of many double vectors tend to be the same, i.e. most doubles in a
/// vector often use the same exponent and front bits. Compression proceeds by finding the best
/// prefix of up to 16 bits that can be collapsed into a dictionary of
/// up to 8 elements. Each double can then be broken into the front/left `L` bits, which neatly
/// bit-packs down to 1-3 bits per element (depending on the actual dictionary size).
/// The remaining `R` bits naturally bit-pack.
///
/// In the ideal case, this scheme allows us to store a sequence of doubles in 49 bits-per-value.
///
/// Our implementation draws on the MIT-licensed [C++ implementation] provided by the original authors.
///
/// [C++ implementation]: https://github.com/cwida/ALP/blob/main/include/alp/rd.hpp
pub struct RDEncoder {
    /// Number of bits kept in the right (least-significant) part of each float.
    right_bit_width: u8,
    /// Reverse dictionary: `codes[code]` is the left-bit pattern assigned to `code`.
    codes: Vec<u16>,
}
151
impl RDEncoder {
    /// Build a new encoder from a sample of doubles.
    ///
    /// The sample is used to pick the left/right split point and the dictionary of the
    /// most frequent left-bit patterns; the resulting encoder can then be applied to
    /// many vectors.
    pub fn new<T>(sample: &[T]) -> Self
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        let dictionary = find_best_dictionary::<T>(sample);

        // Invert the (bits -> code) map into a dense (code -> bits) lookup vector.
        let mut codes = vec![0; dictionary.dictionary.len()];
        dictionary.dictionary.into_iter().for_each(|(bits, code)| {
            // write the reverse mapping into the codes vector.
            codes[code as usize] = bits
        });

        Self {
            right_bit_width: dictionary.right_bit_width,
            codes,
        }
    }

    /// Build a new encoder from known parameters.
    ///
    /// `right_bit_width` is the number of right bits per value; `codes[code]` must be the
    /// left-bit pattern for `code`.
    pub fn from_parts(right_bit_width: u8, codes: Vec<u16>) -> Self {
        Self {
            right_bit_width,
            codes,
        }
    }

    /// Encode a set of floating point values with ALP-RD.
    ///
    /// Each value will be split into a left and right component, which are compressed individually.
    // TODO(joe): make fallible
    pub fn encode(&self, array: ArrayView<'_, Primitive>) -> ALPRDArray {
        match_each_alp_float_ptype!(array.ptype(), |P| { self.encode_generic::<P>(array) })
    }

    /// Monomorphic encoding body; `T` is the float type matching `array.ptype()`.
    fn encode_generic<T>(&self, array: ArrayView<'_, Primitive>) -> ALPRDArray
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        assert!(
            !self.codes.is_empty(),
            "codes lookup table must be populated before RD encoding"
        );

        let doubles = array.as_slice::<T>();

        let mut left_parts: BufferMut<u16> = BufferMut::with_capacity(doubles.len());
        let mut right_parts: BufferMut<T::UINT> = BufferMut::with_capacity(doubles.len());
        // Exception buffers are sized heuristically; they grow if more than 1/4 of the
        // values turn out to be exceptions.
        let mut exceptions_pos: BufferMut<u64> = BufferMut::with_capacity(doubles.len() / 4);
        let mut exceptions: BufferMut<u16> = BufferMut::with_capacity(doubles.len() / 4);

        // mask for right-parts: the lowest `right_bit_width` bits set.
        let right_mask = T::UINT::one().shl(self.right_bit_width as _) - T::UINT::one();
        let max_code = self.codes.len() - 1;
        let left_bit_width = bit_width!(max_code);

        // Split every float's bit pattern into its right (masked) and left (shifted) parts.
        for v in doubles.iter().copied() {
            right_parts.push(T::to_bits(v) & right_mask);
            left_parts.push(<T as ALPRDFloat>::to_u16(
                T::to_bits(v).shr(self.right_bit_width as _),
            ));
        }

        // dict-encode the left-parts, keeping track of exceptions.
        // The linear scan is over at most MAX_DICT_SIZE (8) codes.
        for (idx, left) in left_parts.iter_mut().enumerate() {
            // TODO: revisit if we need to change the branch order for perf.
            if let Some(code) = self.codes.iter().position(|v| *v == *left) {
                *left = code as u16;
            } else {
                exceptions.push(*left);
                exceptions_pos.push(idx as _);

                // Placeholder code; the real left bits are restored from patches on decode.
                *left = 0u16;
            }
        }

        // Bit-pack down the encoded left-parts array that have been dictionary encoded.
        let primitive_left = PrimitiveArray::new(
            left_parts,
            array
                .validity()
                .vortex_expect("ALP RD validity should be derivable"),
        );
        // SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
        let packed_left = unsafe {
            bitpack_encode_unchecked(primitive_left, left_bit_width as _)
                .vortex_expect("bitpack_encode_unchecked should succeed for left parts")
                .into_array()
        };

        let primitive_right = PrimitiveArray::new(right_parts, Validity::NonNullable);
        // SAFETY: by construction, all values in right_parts are right_bit_width + leading zeros.
        let packed_right = unsafe {
            bitpack_encode_unchecked(primitive_right, self.right_bit_width as _)
                .vortex_expect("bitpack_encode_unchecked should succeed for right parts")
                .into_array()
        };

        // Collect exceptions (if any) into Patches; positions are bit-packed to the
        // minimal width that holds the largest index.
        let exceptions = (!exceptions_pos.is_empty()).then(|| {
            // Positions were pushed in ascending order, so the last one is the maximum.
            let max_exc_pos = exceptions_pos.last().copied().unwrap_or_default();
            let bw = bit_width!(max_exc_pos) as u8;

            let exc_pos_array = PrimitiveArray::new(exceptions_pos, Validity::NonNullable);
            // SAFETY: We calculate bw such that it is wide enough to hold the largest position index.
            let packed_pos = unsafe {
                bitpack_encode_unchecked(exc_pos_array, bw)
                    .vortex_expect(
                        "bitpack_encode_unchecked should succeed for exception positions",
                    )
                    .into_array()
            };

            Patches::new(
                doubles.len(),
                0,
                packed_pos,
                exceptions.into_array(),
                // TODO(0ax1): handle chunk offsets
                None,
            )
            .vortex_expect("Patches construction in encode")
        });

        ALPRD::try_new(
            DType::Primitive(T::PTYPE, packed_left.dtype().nullability()),
            packed_left,
            Buffer::<u16>::copy_from(&self.codes),
            packed_right,
            self.right_bit_width,
            exceptions,
        )
        .vortex_expect("ALPRDArray construction in encode")
    }
}
292
/// Decode ALP-RD encoded values back into their original floating point format.
///
/// `left_parts` holds one dictionary code per element, `left_parts_dict` maps a code back to
/// the original left-bit pattern, and `right_parts` holds the low `right_bit_width` bits of
/// each element. `left_parts_patches`, when present, carries the true left-bit patterns for
/// positions whose left part was an exception (stored as code 0 by the encoder).
///
/// # Panics
///
/// Panics if `left_parts` and `right_parts` differ in length.
pub fn alp_rd_decode<T: ALPRDFloat>(
    mut left_parts: BufferMut<u16>,
    left_parts_dict: &[u16],
    right_bit_width: u8,
    right_parts: BufferMut<T::UINT>,
    left_parts_patches: Option<Patches>,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Buffer<T>> {
    if left_parts.len() != right_parts.len() {
        vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len");
    }

    // Left bits sit above the right bits, so they are shifted left by `right_bit_width`.
    let shift = right_bit_width as usize;

    if let Some(patches) = left_parts_patches {
        // Patched path: some left-part codes map to exception values that live outside
        // the dictionary. We must dictionary-decode first, then overwrite the exceptions,
        // before we can combine with right-parts.

        // Dictionary-decode every code in-place (code → actual left bit-pattern).
        for code in left_parts.iter_mut() {
            *code = left_parts_dict[*code as usize];
        }

        // Overwrite exception positions with their true left bit-patterns.
        let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
        let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
        alp_rd_apply_patches(&mut left_parts, &indices, &patch_values, patches.offset());

        // Reconstruct floats by shifting each decoded left value into the MSBs
        // and OR-ing with the corresponding right value.
        alp_rd_combine_inplace::<T>(
            right_parts,
            |right, &left| {
                *right = (<T as ALPRDFloat>::from_u16(left) << shift) | *right;
            },
            left_parts.as_ref(),
        )
    } else {
        // Non-patched fast path: every code maps through the dictionary, so we can
        // pre-shift the entire dictionary once and reduce the per-element hot loop to
        // a single table lookup + OR.
        let mut shifted_dict = [T::UINT::default(); MAX_DICT_SIZE as usize];
        for (i, &entry) in left_parts_dict.iter().enumerate() {
            shifted_dict[i] = <T as ALPRDFloat>::from_u16(entry) << shift;
        }

        // Each element: look up the pre-shifted left value by code, OR with right-parts.
        alp_rd_combine_inplace::<T>(
            right_parts,
            |right, &code| {
                // SAFETY: codes are bounded by dict size (< left_parts_dict.len() <= MAX_DICT_SIZE).
                *right = unsafe { *shifted_dict.get_unchecked(code as usize) } | *right;
            },
            left_parts.as_ref(),
        )
    }
}
356
357/// Apply patches to the decoded left-parts values.
358fn alp_rd_apply_patches(
359    values: &mut BufferMut<u16>,
360    indices: &PrimitiveArray,
361    patch_values: &PrimitiveArray,
362    offset: usize,
363) {
364    match_each_integer_ptype!(indices.ptype(), |T| {
365        indices
366            .as_slice::<T>()
367            .iter()
368            .copied()
369            .map(|idx| idx - offset as T)
370            .zip(patch_values.as_slice::<u16>().iter())
371            .for_each(|(idx, v)| values[idx as usize] = *v);
372    })
373}
374
375/// Zip `right_parts` with `left_data`, apply `combine_fn` per element, then reinterpret the
376/// buffer from `T::UINT` to `T` (same bit-width: u32↔f32, u64↔f64).
377fn alp_rd_combine_inplace<T: ALPRDFloat>(
378    mut right_parts: BufferMut<T::UINT>,
379    combine_fn: impl Fn(&mut T::UINT, &u16),
380    left_data: &[u16],
381) -> VortexResult<Buffer<T>> {
382    for (right, left) in right_parts.as_mut_slice().iter_mut().zip(left_data.iter()) {
383        combine_fn(right, left);
384    }
385    // SAFETY: all bit patterns of T::UINT are valid T (u32↔f32 or u64↔f64).
386    Ok(unsafe { right_parts.transmute::<T>() }.freeze())
387}
/// Find the best "cut point" for a set of floating point values such that we can
/// cast them all to the relevant value instead.
///
/// Tries every split point `p` in `1..=16`, builds a candidate dictionary for each, and
/// keeps the candidate with the lowest estimated bits-per-value.
fn find_best_dictionary<T: ALPRDFloat>(samples: &[T]) -> ALPRDDictionary {
    let mut best_est_size = f64::MAX;
    let mut best_dict = ALPRDDictionary::default();

    // NOTE(review): the upper bound 16 mirrors `CUT_LIMIT`; keep the two in sync.
    for p in 1..=16 {
        // Cutting `p` front bits leaves `T::BITS - p` right bits.
        let candidate_right_bw = (T::BITS - p) as u8;
        let (dictionary, exception_count) =
            build_left_parts_dictionary::<T>(samples, candidate_right_bw, MAX_DICT_SIZE);
        let estimated_size = estimate_compression_size(
            dictionary.right_bit_width,
            dictionary.left_bit_width,
            exception_count,
            samples.len(),
        );
        // Strict `<` keeps the first (smallest-p) candidate on ties.
        if estimated_size < best_est_size {
            best_est_size = estimated_size;
            best_dict = dictionary;
        }
    }

    best_dict
}
412
413/// Build dictionary of the leftmost bits.
414fn build_left_parts_dictionary<T: ALPRDFloat>(
415    samples: &[T],
416    right_bw: u8,
417    max_dict_size: u8,
418) -> (ALPRDDictionary, usize) {
419    assert!(
420        right_bw >= (T::BITS - CUT_LIMIT) as _,
421        "left-parts must be <= 16 bits"
422    );
423
424    // Count the number of occurrences of each left bit pattern
425    let mut counts = HashMap::new();
426    samples
427        .iter()
428        .copied()
429        .map(|v| <T as ALPRDFloat>::to_u16(T::to_bits(v).shr(right_bw as _)))
430        .for_each(|item| *counts.entry(item).or_default() += 1);
431
432    // Sorted counts: sort by negative count so that heavy hitters sort first.
433    let mut sorted_bit_counts: Vec<(u16, usize)> = counts.into_iter().collect_vec();
434    sorted_bit_counts.sort_by_key(|(_, count)| count.wrapping_neg());
435
436    // Assign the most-frequently occurring left-bits as dictionary codes, up to `dict_size`...
437    let mut dictionary = HashMap::with_capacity_and_hasher(max_dict_size as _, FxBuildHasher);
438    let mut code = 0u16;
439    while code < (max_dict_size as _) && (code as usize) < sorted_bit_counts.len() {
440        let (bits, _) = sorted_bit_counts[code as usize];
441        dictionary.insert(bits, code);
442        code += 1;
443    }
444
445    // ...and the rest are exceptions.
446    let exception_count: usize = sorted_bit_counts
447        .iter()
448        .skip(code as _)
449        .map(|(_, count)| *count)
450        .sum();
451
452    // Left bit-width is determined based on the actual dictionary size.
453    let max_code = dictionary.len() - 1;
454    let left_bw = bit_width!(max_code) as u8;
455
456    (
457        ALPRDDictionary {
458            dictionary,
459            right_bit_width: right_bw,
460            left_bit_width: left_bw,
461        },
462        exception_count,
463    )
464}
465
/// Estimate the bits-per-value when using these compression settings.
///
/// The estimate is the packed width of the left and right parts, plus the amortized cost
/// of the exceptions spread over the whole sample.
fn estimate_compression_size(
    right_bw: u8,
    left_bw: u8,
    exception_count: usize,
    sample_n: usize,
) -> f64 {
    // Each exception costs a two-byte position plus a two-byte value (up to 16 front bits).
    const BITS_PER_EXCEPTION: usize = 16 + 16;

    let packed_bits = f64::from(right_bw) + f64::from(left_bw);
    let exception_overhead = ((exception_count * BITS_PER_EXCEPTION) as f64) / (sample_n as f64);
    packed_bits + exception_overhead
}
479
/// The ALP-RD dictionary, encoding the "left parts" and their dictionary encoding.
#[derive(Debug, Default)]
struct ALPRDDictionary {
    /// Items in the dictionary are bit patterns, along with their 16-bit encoding.
    /// Codes are assigned in descending frequency order, starting at 0.
    dictionary: HashMap<u16, u16, FxBuildHasher>,
    /// The (compressed) left bit width. This is after bit-packing the dictionary codes.
    left_bit_width: u8,
    /// The right bit width. This is the bit-packed width of each of the "real double" values.
    right_bit_width: u8,
}