Skip to main content

vortex_alp/alp_rd/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#![expect(clippy::cast_possible_truncation)]
5
6pub use array::*;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::patches::Patches;
10use vortex_array::validity::Validity;
11use vortex_fastlanes::bitpack_compress::bitpack_encode_unchecked;
12
13mod array;
14mod compute;
15mod kernel;
16mod ops;
17mod rules;
18mod slice;
19
20use std::ops::Shl;
21use std::ops::Shr;
22
23use itertools::Itertools;
24use num_traits::Float;
25use num_traits::One;
26use num_traits::PrimInt;
27use rustc_hash::FxBuildHasher;
28use vortex_array::ArrayView;
29use vortex_array::arrays::Primitive;
30use vortex_array::arrays::PrimitiveArray;
31use vortex_array::dtype::DType;
32use vortex_array::dtype::NativePType;
33use vortex_array::match_each_integer_ptype;
34use vortex_buffer::Buffer;
35use vortex_buffer::BufferMut;
36use vortex_error::VortexExpect;
37use vortex_error::VortexResult;
38use vortex_error::vortex_panic;
39use vortex_utils::aliases::hash_map::HashMap;
40
41use crate::match_each_alp_float_ptype;
42
// Number of bits needed to represent the unsigned integer `$value`, as a `usize`.
// Zero is treated as needing one bit, so a single-entry dictionary (max code 0) still gets a
// non-zero packing width.
macro_rules! bit_width {
    ($value:expr) => {
        match $value {
            0 => 1,
            nonzero => (nonzero.ilog2() + 1) as usize,
        }
    };
}
52
/// Upper bound on how many of the most-significant bits of each float may be split off into the
/// dictionary-encoded "left" part.
const CUT_LIMIT: usize = 16;

/// Maximum number of distinct left-part bit patterns that receive a dictionary code; every other
/// left-part pattern is stored as an exception.
const MAX_DICT_SIZE: u8 = 8;
57
/// Sealing module for `ALPRDFloat`.
///
/// `Sealed` is a public trait inside a private module: it can be named as a supertrait here, but
/// code outside this module tree cannot name it, and therefore cannot implement `ALPRDFloat`.
mod private {
    /// Marker supertrait; implemented only for the float types supported by ALP-RD.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
64
65/// Main trait for ALP-RD encodable floating point numbers.
66///
67/// Like the paper, we limit this to the IEEE7 754 single-precision (`f32`) and double-precision
68/// (`f64`) floating point types.
69pub trait ALPRDFloat: private::Sealed + Float + Copy + NativePType {
70    /// The unsigned integer type with the same bit-width as the floating-point type.
71    type UINT: NativePType + PrimInt + One + Copy;
72
73    /// Number of bits the value occupies in registers.
74    const BITS: usize = size_of::<Self>() * 8;
75
76    /// Bit-wise transmute from the unsigned integer type to the floating-point type.
77    fn from_bits(bits: Self::UINT) -> Self;
78
79    /// Bit-wise transmute into the unsigned integer type.
80    fn to_bits(value: Self) -> Self::UINT;
81
82    /// Truncating conversion from the unsigned integer type to `u16`.
83    fn to_u16(bits: Self::UINT) -> u16;
84
85    /// Type-widening conversion from `u16` to the unsigned integer type.
86    fn from_u16(value: u16) -> Self::UINT;
87}
88
89impl ALPRDFloat for f64 {
90    type UINT = u64;
91
92    fn from_bits(bits: Self::UINT) -> Self {
93        f64::from_bits(bits)
94    }
95
96    fn to_bits(value: Self) -> Self::UINT {
97        value.to_bits()
98    }
99
100    fn to_u16(bits: Self::UINT) -> u16 {
101        bits as u16
102    }
103
104    fn from_u16(value: u16) -> Self::UINT {
105        value as u64
106    }
107}
108
109impl ALPRDFloat for f32 {
110    type UINT = u32;
111
112    fn from_bits(bits: Self::UINT) -> Self {
113        f32::from_bits(bits)
114    }
115
116    fn to_bits(value: Self) -> Self::UINT {
117        value.to_bits()
118    }
119
120    fn to_u16(bits: Self::UINT) -> u16 {
121        bits as u16
122    }
123
124    fn from_u16(value: u16) -> Self::UINT {
125        value as u32
126    }
127}
128
/// Encoder for ALP-RD ("real doubles") values.
///
/// The encoder calculates its parameters from a single sample of floating-point values
/// and can then be applied to many vectors.
///
/// ALP-RD uses the algorithm outlined in Section 3.4 of the paper. The crux of it is that the front
/// (most significant) bits of many double vectors tend to be the same, i.e. most doubles in a
/// vector often use the same exponent and front bits. Compression proceeds by finding the best
/// prefix of up to 16 bits that can be collapsed into a dictionary of up to 8 elements.
///
/// Each double is then split into two parts:
/// - its front/left `L` bits, which are replaced by their dictionary code and bit-pack down to
///   1-3 bits per element (depending on the actual dictionary size);
/// - its remaining right `R` bits, which are bit-packed as-is.
///
/// Left parts that are not in the dictionary are stored separately as exceptions.
///
/// In the ideal case, this scheme allows us to store a sequence of doubles in 49 bits-per-value.
///
/// Our implementation draws on the MIT-licensed [C++ implementation] provided by the original
/// authors.
///
/// [C++ implementation]: https://github.com/cwida/ALP/blob/main/include/alp/rd.hpp
pub struct RDEncoder {
    /// Number of low-order bits of each value that form the right part.
    right_bit_width: u8,
    /// Dictionary of left-part bit patterns, indexed by dictionary code.
    codes: Vec<u16>,
}
151
152impl RDEncoder {
153    /// Build a new encoder from a sample of doubles.
154    pub fn new<T>(sample: &[T]) -> Self
155    where
156        T: ALPRDFloat + NativePType,
157        T::UINT: NativePType,
158    {
159        let dictionary = find_best_dictionary::<T>(sample);
160
161        let mut codes = vec![0; dictionary.dictionary.len()];
162        dictionary.dictionary.into_iter().for_each(|(bits, code)| {
163            // write the reverse mapping into the codes vector.
164            codes[code as usize] = bits
165        });
166
167        Self {
168            right_bit_width: dictionary.right_bit_width,
169            codes,
170        }
171    }
172
173    /// Build a new encoder from known parameters.
174    pub fn from_parts(right_bit_width: u8, codes: Vec<u16>) -> Self {
175        Self {
176            right_bit_width,
177            codes,
178        }
179    }
180
181    /// Encode a set of floating point values with ALP-RD.
182    ///
183    /// Each value will be split into a left and right component, which are compressed individually.
184    // TODO(joe): make fallible
185    pub fn encode(&self, array: ArrayView<'_, Primitive>, ctx: &mut ExecutionCtx) -> ALPRDArray {
186        match_each_alp_float_ptype!(array.ptype(), |P| { self.encode_generic::<P>(array, ctx) })
187    }
188
189    fn encode_generic<T>(
190        &self,
191        array: ArrayView<'_, Primitive>,
192        ctx: &mut ExecutionCtx,
193    ) -> ALPRDArray
194    where
195        T: ALPRDFloat + NativePType,
196        T::UINT: NativePType,
197    {
198        assert!(
199            !self.codes.is_empty(),
200            "codes lookup table must be populated before RD encoding"
201        );
202
203        let doubles = array.as_slice::<T>();
204
205        let mut left_parts: BufferMut<u16> = BufferMut::with_capacity(doubles.len());
206        let mut right_parts: BufferMut<T::UINT> = BufferMut::with_capacity(doubles.len());
207        let mut exceptions_pos: BufferMut<u64> = BufferMut::with_capacity(doubles.len() / 4);
208        let mut exceptions: BufferMut<u16> = BufferMut::with_capacity(doubles.len() / 4);
209
210        // mask for right-parts
211        let right_mask = T::UINT::one().shl(self.right_bit_width as _) - T::UINT::one();
212        let max_code = self.codes.len() - 1;
213        let left_bit_width = bit_width!(max_code);
214
215        for v in doubles.iter().copied() {
216            right_parts.push(T::to_bits(v) & right_mask);
217            left_parts.push(<T as ALPRDFloat>::to_u16(
218                T::to_bits(v).shr(self.right_bit_width as _),
219            ));
220        }
221
222        // dict-encode the left-parts, keeping track of exceptions
223        for (idx, left) in left_parts.iter_mut().enumerate() {
224            // TODO: revisit if we need to change the branch order for perf.
225            if let Some(code) = self.codes.iter().position(|v| *v == *left) {
226                *left = code as u16;
227            } else {
228                exceptions.push(*left);
229                exceptions_pos.push(idx as _);
230
231                *left = 0u16;
232            }
233        }
234
235        // Bit-pack down the encoded left-parts array that have been dictionary encoded.
236        let primitive_left = PrimitiveArray::new(
237            left_parts,
238            array
239                .validity()
240                .vortex_expect("ALP RD validity should be derivable"),
241        );
242        // SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
243        let packed_left = unsafe {
244            bitpack_encode_unchecked(primitive_left, left_bit_width as _)
245                .vortex_expect("bitpack_encode_unchecked should succeed for left parts")
246                .into_array()
247        };
248
249        let primitive_right = PrimitiveArray::new(right_parts, Validity::NonNullable);
250        // SAFETY: by construction, all values in right_parts are right_bit_width + leading zeros.
251        let packed_right = unsafe {
252            bitpack_encode_unchecked(primitive_right, self.right_bit_width as _)
253                .vortex_expect("bitpack_encode_unchecked should succeed for right parts")
254                .into_array()
255        };
256
257        // Bit-pack the dict-encoded left-parts
258        // Bit-pack the right-parts
259        // Patches for exceptions.
260        let exceptions = (!exceptions_pos.is_empty()).then(|| {
261            let max_exc_pos = exceptions_pos.last().copied().unwrap_or_default();
262            let bw = bit_width!(max_exc_pos) as u8;
263
264            let exc_pos_array = PrimitiveArray::new(exceptions_pos, Validity::NonNullable);
265            // SAFETY: We calculate bw such that it is wide enough to hold the largest position index.
266            let packed_pos = unsafe {
267                bitpack_encode_unchecked(exc_pos_array, bw)
268                    .vortex_expect(
269                        "bitpack_encode_unchecked should succeed for exception positions",
270                    )
271                    .into_array()
272            };
273
274            Patches::new(
275                doubles.len(),
276                0,
277                packed_pos,
278                exceptions.into_array(),
279                // TODO(0ax1): handle chunk offsets
280                None,
281            )
282            .vortex_expect("Patches construction in encode")
283        });
284
285        ALPRD::try_new(
286            DType::Primitive(T::PTYPE, packed_left.dtype().nullability()),
287            packed_left,
288            Buffer::<u16>::copy_from(&self.codes),
289            packed_right,
290            self.right_bit_width,
291            exceptions,
292            ctx,
293        )
294        .vortex_expect("ALPRDArray construction in encode")
295    }
296}
297
298/// Decode ALP-RD encoded values back into their original floating point format.
299///
300/// # Panics
301///
302/// Panics if `left_parts` and `right_parts` differ in length.
303pub fn alp_rd_decode<T: ALPRDFloat>(
304    mut left_parts: BufferMut<u16>,
305    left_parts_dict: &[u16],
306    right_bit_width: u8,
307    right_parts: BufferMut<T::UINT>,
308    left_parts_patches: Option<Patches>,
309    ctx: &mut ExecutionCtx,
310) -> VortexResult<Buffer<T>> {
311    if left_parts.len() != right_parts.len() {
312        vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len");
313    }
314
315    let shift = right_bit_width as usize;
316
317    if let Some(patches) = left_parts_patches {
318        // Patched path: some left-part codes map to exception values that live outside
319        // the dictionary. We must dictionary-decode first, then overwrite the exceptions,
320        // before we can combine with right-parts.
321
322        // Dictionary-decode every code in-place (code → actual left bit-pattern).
323        for code in left_parts.iter_mut() {
324            *code = left_parts_dict[*code as usize];
325        }
326
327        // Overwrite exception positions with their true left bit-patterns.
328        let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
329        let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
330        alp_rd_apply_patches(&mut left_parts, &indices, &patch_values, patches.offset());
331
332        // Reconstruct floats by shifting each decoded left value into the MSBs
333        // and OR-ing with the corresponding right value.
334        alp_rd_combine_inplace::<T>(
335            right_parts,
336            |right, &left| {
337                *right = (<T as ALPRDFloat>::from_u16(left) << shift) | *right;
338            },
339            left_parts.as_ref(),
340        )
341    } else {
342        // Non-patched fast path: every code maps through the dictionary, so we can
343        // pre-shift the entire dictionary once and reduce the per-element hot loop to
344        // a single table lookup + OR.
345        let mut shifted_dict = [T::UINT::default(); MAX_DICT_SIZE as usize];
346        for (i, &entry) in left_parts_dict.iter().enumerate() {
347            shifted_dict[i] = <T as ALPRDFloat>::from_u16(entry) << shift;
348        }
349
350        // Each element: look up the pre-shifted left value by code, OR with right-parts.
351        alp_rd_combine_inplace::<T>(
352            right_parts,
353            |right, &code| {
354                // SAFETY: codes are bounded by dict size (< left_parts_dict.len() <= MAX_DICT_SIZE).
355                *right = unsafe { *shifted_dict.get_unchecked(code as usize) } | *right;
356            },
357            left_parts.as_ref(),
358        )
359    }
360}
361
362/// Apply patches to the decoded left-parts values.
363fn alp_rd_apply_patches(
364    values: &mut BufferMut<u16>,
365    indices: &PrimitiveArray,
366    patch_values: &PrimitiveArray,
367    offset: usize,
368) {
369    match_each_integer_ptype!(indices.ptype(), |T| {
370        indices
371            .as_slice::<T>()
372            .iter()
373            .copied()
374            .map(|idx| idx - offset as T)
375            .zip(patch_values.as_slice::<u16>().iter())
376            .for_each(|(idx, v)| values[idx as usize] = *v);
377    })
378}
379
380/// Zip `right_parts` with `left_data`, apply `combine_fn` per element, then reinterpret the
381/// buffer from `T::UINT` to `T` (same bit-width: u32↔f32, u64↔f64).
382fn alp_rd_combine_inplace<T: ALPRDFloat>(
383    mut right_parts: BufferMut<T::UINT>,
384    combine_fn: impl Fn(&mut T::UINT, &u16),
385    left_data: &[u16],
386) -> VortexResult<Buffer<T>> {
387    for (right, left) in right_parts.as_mut_slice().iter_mut().zip(left_data.iter()) {
388        combine_fn(right, left);
389    }
390    // SAFETY: all bit patterns of T::UINT are valid T (u32↔f32 or u64↔f64).
391    Ok(unsafe { right_parts.transmute::<T>() }.freeze())
392}
393/// Find the best "cut point" for a set of floating point values such that we can
394/// cast them all to the relevant value instead.
395fn find_best_dictionary<T: ALPRDFloat>(samples: &[T]) -> ALPRDDictionary {
396    let mut best_est_size = f64::MAX;
397    let mut best_dict = ALPRDDictionary::default();
398
399    for p in 1..=16 {
400        let candidate_right_bw = (T::BITS - p) as u8;
401        let (dictionary, exception_count) =
402            build_left_parts_dictionary::<T>(samples, candidate_right_bw, MAX_DICT_SIZE);
403        let estimated_size = estimate_compression_size(
404            dictionary.right_bit_width,
405            dictionary.left_bit_width,
406            exception_count,
407            samples.len(),
408        );
409        if estimated_size < best_est_size {
410            best_est_size = estimated_size;
411            best_dict = dictionary;
412        }
413    }
414
415    best_dict
416}
417
418/// Build dictionary of the leftmost bits.
419fn build_left_parts_dictionary<T: ALPRDFloat>(
420    samples: &[T],
421    right_bw: u8,
422    max_dict_size: u8,
423) -> (ALPRDDictionary, usize) {
424    assert!(
425        right_bw >= (T::BITS - CUT_LIMIT) as _,
426        "left-parts must be <= 16 bits"
427    );
428
429    // Count the number of occurrences of each left bit pattern
430    let mut counts = HashMap::new();
431    samples
432        .iter()
433        .copied()
434        .map(|v| <T as ALPRDFloat>::to_u16(T::to_bits(v).shr(right_bw as _)))
435        .for_each(|item| *counts.entry(item).or_default() += 1);
436
437    // Sorted counts: sort by negative count so that heavy hitters sort first.
438    let mut sorted_bit_counts: Vec<(u16, usize)> = counts.into_iter().collect_vec();
439    sorted_bit_counts.sort_by_key(|(_, count)| count.wrapping_neg());
440
441    // Assign the most-frequently occurring left-bits as dictionary codes, up to `dict_size`...
442    let mut dictionary = HashMap::with_capacity_and_hasher(max_dict_size as _, FxBuildHasher);
443    let mut code = 0u16;
444    while code < (max_dict_size as _) && (code as usize) < sorted_bit_counts.len() {
445        let (bits, _) = sorted_bit_counts[code as usize];
446        dictionary.insert(bits, code);
447        code += 1;
448    }
449
450    // ...and the rest are exceptions.
451    let exception_count: usize = sorted_bit_counts
452        .iter()
453        .skip(code as _)
454        .map(|(_, count)| *count)
455        .sum();
456
457    // Left bit-width is determined based on the actual dictionary size.
458    let max_code = dictionary.len() - 1;
459    let left_bw = bit_width!(max_code) as u8;
460
461    (
462        ALPRDDictionary {
463            dictionary,
464            right_bit_width: right_bw,
465            left_bit_width: left_bw,
466        },
467        exception_count,
468    )
469}
470
/// Estimate the bits-per-value when using these compression settings.
///
/// The estimate is the right- and left-part bit widths plus the exception overhead (a 16-bit
/// position and a 16-bit value per exception) amortised over `sample_n` values.
///
/// With `sample_n == 0` there is nothing to amortise over, so the exception term is treated as
/// zero instead of producing NaN (`0.0 / 0.0`). NaN would make every candidate compare as "not
/// smaller".
fn estimate_compression_size(
    right_bw: u8,
    left_bw: u8,
    exception_count: usize,
    sample_n: usize,
) -> f64 {
    const EXC_POSITION_SIZE: usize = 16; // two bytes for exception position.
    const EXC_SIZE: usize = 16; // two bytes for each exception (up to 16 front bits).

    let exceptions_size = exception_count * (EXC_POSITION_SIZE + EXC_SIZE);
    let exceptions_per_value = if sample_n == 0 {
        0.0
    } else {
        (exceptions_size as f64) / (sample_n as f64)
    };
    f64::from(right_bw) + f64::from(left_bw) + exceptions_per_value
}
484
485/// The ALP-RD dictionary, encoding the "left parts" and their dictionary encoding.
486#[derive(Debug, Default)]
487struct ALPRDDictionary {
488    /// Items in the dictionary are bit patterns, along with their 16-bit encoding.
489    dictionary: HashMap<u16, u16, FxBuildHasher>,
490    /// The (compressed) left bit width. This is after bit-packing the dictionary codes.
491    left_bit_width: u8,
492    /// The right bit width. This is the bit-packed width of each of the "real double" values.
493    right_bit_width: u8,
494}