Skip to main content

vortex_alp/alp_rd/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#![allow(clippy::cast_possible_truncation)]
5
6pub use array::*;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::patches::Patches;
10use vortex_array::validity::Validity;
11use vortex_fastlanes::bitpack_compress::bitpack_encode_unchecked;
12
13mod array;
14mod compute;
15mod kernel;
16mod ops;
17mod rules;
18mod slice;
19
20use std::ops::Shl;
21use std::ops::Shr;
22
23use itertools::Itertools;
24use num_traits::Float;
25use num_traits::One;
26use num_traits::PrimInt;
27use rustc_hash::FxBuildHasher;
28use vortex_array::Array;
29use vortex_array::arrays::PrimitiveArray;
30use vortex_array::dtype::DType;
31use vortex_array::dtype::NativePType;
32use vortex_array::match_each_integer_ptype;
33use vortex_array::vtable::ValidityHelper;
34use vortex_buffer::Buffer;
35use vortex_buffer::BufferMut;
36use vortex_error::VortexExpect;
37use vortex_error::VortexResult;
38use vortex_error::vortex_panic;
39use vortex_utils::aliases::hash_map::HashMap;
40
41use crate::match_each_alp_float_ptype;
42
/// Number of bits needed to represent `$value` (at least 1, so that zero still takes one bit).
macro_rules! bit_width {
    ($value:expr) => {
        match $value {
            0 => 1,
            v => (v.ilog2() + 1) as usize,
        }
    };
}
52
/// Max number of bits to cut from the MSB section of each float.
///
/// Left parts are carried around as `u16`, so the cut can be at most that wide (16 bits).
const CUT_LIMIT: usize = u16::BITS as usize;

/// Maximum number of entries in the left-parts dictionary (codes `0..8`, i.e. at most 3 bits).
const MAX_DICT_SIZE: u8 = 8;
57
/// Sealing module: only the types implementing `Sealed` here can implement `ALPRDFloat`.
mod private {
    /// Marker supertrait that outside crates cannot name, and therefore cannot implement.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
64
65/// Main trait for ALP-RD encodable floating point numbers.
66///
67/// Like the paper, we limit this to the IEEE7 754 single-precision (`f32`) and double-precision
68/// (`f64`) floating point types.
69pub trait ALPRDFloat: private::Sealed + Float + Copy + NativePType {
70    /// The unsigned integer type with the same bit-width as the floating-point type.
71    type UINT: NativePType + PrimInt + One + Copy;
72
73    /// Number of bits the value occupies in registers.
74    const BITS: usize = size_of::<Self>() * 8;
75
76    /// Bit-wise transmute from the unsigned integer type to the floating-point type.
77    fn from_bits(bits: Self::UINT) -> Self;
78
79    /// Bit-wise transmute into the unsigned integer type.
80    fn to_bits(value: Self) -> Self::UINT;
81
82    /// Truncating conversion from the unsigned integer type to `u16`.
83    fn to_u16(bits: Self::UINT) -> u16;
84
85    /// Type-widening conversion from `u16` to the unsigned integer type.
86    fn from_u16(value: u16) -> Self::UINT;
87}
88
89impl ALPRDFloat for f64 {
90    type UINT = u64;
91
92    fn from_bits(bits: Self::UINT) -> Self {
93        f64::from_bits(bits)
94    }
95
96    fn to_bits(value: Self) -> Self::UINT {
97        value.to_bits()
98    }
99
100    fn to_u16(bits: Self::UINT) -> u16 {
101        bits as u16
102    }
103
104    fn from_u16(value: u16) -> Self::UINT {
105        value as u64
106    }
107}
108
109impl ALPRDFloat for f32 {
110    type UINT = u32;
111
112    fn from_bits(bits: Self::UINT) -> Self {
113        f32::from_bits(bits)
114    }
115
116    fn to_bits(value: Self) -> Self::UINT {
117        value.to_bits()
118    }
119
120    fn to_u16(bits: Self::UINT) -> u16 {
121        bits as u16
122    }
123
124    fn from_u16(value: u16) -> Self::UINT {
125        value as u32
126    }
127}
128
/// Encoder for ALP-RD ("real doubles") values.
///
/// The encoder calculates its parameters from a single sample of floating-point values,
/// and then can be applied to many vectors.
///
/// ALP-RD uses the algorithm outlined in Section 3.4 of the paper. The crux of it is that the front
/// (most significant) bits of many double vectors tend to be the same, i.e. most doubles in a
/// vector often use the same exponent and front bits. Compression proceeds by finding the best
/// prefix of up to 16 bits that can be collapsed into a dictionary of up to 8 elements. Each
/// double is then split into two parts:
///
/// * its front/left `L` bits, which are replaced by a dictionary code that bit-packs down to
///   1-3 bits per element (depending on the actual dictionary size); left bits missing from the
///   dictionary are stored separately as exceptions, and
/// * its remaining right `R` bits, which are bit-packed as-is.
///
/// In the ideal case, this scheme allows us to store a sequence of doubles in 49 bits-per-value.
///
/// Our implementation draws on the MIT-licensed [C++ implementation] provided by the original authors.
///
/// [C++ implementation]: https://github.com/cwida/ALP/blob/main/include/alp/rd.hpp
pub struct RDEncoder {
    /// Number of low bits of each value kept as the right part.
    right_bit_width: u8,
    /// Left-bit patterns indexed by dictionary code (`codes[code] == left_bits`).
    codes: Vec<u16>,
}
151
152impl RDEncoder {
153    /// Build a new encoder from a sample of doubles.
154    pub fn new<T>(sample: &[T]) -> Self
155    where
156        T: ALPRDFloat + NativePType,
157        T::UINT: NativePType,
158    {
159        let dictionary = find_best_dictionary::<T>(sample);
160
161        let mut codes = vec![0; dictionary.dictionary.len()];
162        dictionary.dictionary.into_iter().for_each(|(bits, code)| {
163            // write the reverse mapping into the codes vector.
164            codes[code as usize] = bits
165        });
166
167        Self {
168            right_bit_width: dictionary.right_bit_width,
169            codes,
170        }
171    }
172
173    /// Build a new encoder from known parameters.
174    pub fn from_parts(right_bit_width: u8, codes: Vec<u16>) -> Self {
175        Self {
176            right_bit_width,
177            codes,
178        }
179    }
180
181    /// Encode a set of floating point values with ALP-RD.
182    ///
183    /// Each value will be split into a left and right component, which are compressed individually.
184    // TODO(joe): make fallible
185    pub fn encode(&self, array: &PrimitiveArray) -> ALPRDArray {
186        match_each_alp_float_ptype!(array.ptype(), |P| { self.encode_generic::<P>(array) })
187    }
188
189    fn encode_generic<T>(&self, array: &PrimitiveArray) -> ALPRDArray
190    where
191        T: ALPRDFloat + NativePType,
192        T::UINT: NativePType,
193    {
194        assert!(
195            !self.codes.is_empty(),
196            "codes lookup table must be populated before RD encoding"
197        );
198
199        let doubles = array.as_slice::<T>();
200
201        let mut left_parts: BufferMut<u16> = BufferMut::with_capacity(doubles.len());
202        let mut right_parts: BufferMut<T::UINT> = BufferMut::with_capacity(doubles.len());
203        let mut exceptions_pos: BufferMut<u64> = BufferMut::with_capacity(doubles.len() / 4);
204        let mut exceptions: BufferMut<u16> = BufferMut::with_capacity(doubles.len() / 4);
205
206        // mask for right-parts
207        let right_mask = T::UINT::one().shl(self.right_bit_width as _) - T::UINT::one();
208        let max_code = self.codes.len() - 1;
209        let left_bit_width = bit_width!(max_code);
210
211        for v in doubles.iter().copied() {
212            right_parts.push(T::to_bits(v) & right_mask);
213            left_parts.push(<T as ALPRDFloat>::to_u16(
214                T::to_bits(v).shr(self.right_bit_width as _),
215            ));
216        }
217
218        // dict-encode the left-parts, keeping track of exceptions
219        for (idx, left) in left_parts.iter_mut().enumerate() {
220            // TODO: revisit if we need to change the branch order for perf.
221            if let Some(code) = self.codes.iter().position(|v| *v == *left) {
222                *left = code as u16;
223            } else {
224                exceptions.push(*left);
225                exceptions_pos.push(idx as _);
226
227                *left = 0u16;
228            }
229        }
230
231        // Bit-pack down the encoded left-parts array that have been dictionary encoded.
232        let primitive_left = PrimitiveArray::new(left_parts, array.validity().clone());
233        // SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
234        let packed_left = unsafe {
235            bitpack_encode_unchecked(primitive_left, left_bit_width as _)
236                .vortex_expect("bitpack_encode_unchecked should succeed for left parts")
237                .into_array()
238        };
239
240        let primitive_right = PrimitiveArray::new(right_parts, Validity::NonNullable);
241        // SAFETY: by construction, all values in right_parts are right_bit_width + leading zeros.
242        let packed_right = unsafe {
243            bitpack_encode_unchecked(primitive_right, self.right_bit_width as _)
244                .vortex_expect("bitpack_encode_unchecked should succeed for right parts")
245                .into_array()
246        };
247
248        // Bit-pack the dict-encoded left-parts
249        // Bit-pack the right-parts
250        // Patches for exceptions.
251        let exceptions = (!exceptions_pos.is_empty()).then(|| {
252            let max_exc_pos = exceptions_pos.last().copied().unwrap_or_default();
253            let bw = bit_width!(max_exc_pos) as u8;
254
255            let exc_pos_array = PrimitiveArray::new(exceptions_pos, Validity::NonNullable);
256            // SAFETY: We calculate bw such that it is wide enough to hold the largest position index.
257            let packed_pos = unsafe {
258                bitpack_encode_unchecked(exc_pos_array, bw)
259                    .vortex_expect(
260                        "bitpack_encode_unchecked should succeed for exception positions",
261                    )
262                    .into_array()
263            };
264
265            Patches::new(
266                doubles.len(),
267                0,
268                packed_pos,
269                exceptions.into_array(),
270                // TODO(0ax1): handle chunk offsets
271                None,
272            )
273            .vortex_expect("Patches construction in encode")
274        });
275
276        ALPRDArray::try_new(
277            DType::Primitive(T::PTYPE, packed_left.dtype().nullability()),
278            packed_left,
279            Buffer::<u16>::copy_from(&self.codes),
280            packed_right,
281            self.right_bit_width,
282            exceptions,
283        )
284        .vortex_expect("ALPRDArray construction in encode")
285    }
286}
287
288/// Decode a vector of ALP-RD encoded values back into their original floating point format.
289///
290/// # Panics
291///
292/// The function panics if the provided `left_parts` and `right_parts` differ in length.
293pub fn alp_rd_decode<T: ALPRDFloat>(
294    left_parts: Buffer<u16>,
295    left_parts_dict: &[u16],
296    right_bit_width: u8,
297    right_parts: BufferMut<T::UINT>,
298    left_parts_patches: Option<&Patches>,
299    ctx: &mut ExecutionCtx,
300) -> VortexResult<Buffer<T>> {
301    if left_parts.len() != right_parts.len() {
302        vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len");
303    }
304
305    // Decode the left-parts dictionary
306    let mut values = BufferMut::<u16>::from_iter(
307        left_parts
308            .iter()
309            .map(|code| left_parts_dict[*code as usize]),
310    );
311
312    // Apply any patches
313    if let Some(patches) = left_parts_patches {
314        let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
315        let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
316        alp_rd_apply_patches(&mut values, &indices, &patch_values, patches.offset());
317    }
318
319    // Shift the left-parts and add in the right-parts.
320    Ok(alp_rd_decode_core(
321        left_parts_dict,
322        right_bit_width,
323        right_parts,
324        values,
325    ))
326}
327
328/// Apply patches to the decoded left-parts values.
329fn alp_rd_apply_patches(
330    values: &mut BufferMut<u16>,
331    indices: &PrimitiveArray,
332    patch_values: &PrimitiveArray,
333    offset: usize,
334) {
335    match_each_integer_ptype!(indices.ptype(), |T| {
336        indices
337            .as_slice::<T>()
338            .iter()
339            .copied()
340            .map(|idx| idx - offset as T)
341            .zip(patch_values.as_slice::<u16>().iter())
342            .for_each(|(idx, v)| values[idx as usize] = *v);
343    })
344}
345
346/// Core decode logic shared between `alp_rd_decode` and `execute_alp_rd_decode`.
347fn alp_rd_decode_core<T: ALPRDFloat>(
348    _left_parts_dict: &[u16],
349    right_bit_width: u8,
350    right_parts: BufferMut<T::UINT>,
351    values: BufferMut<u16>,
352) -> Buffer<T> {
353    // Shift the left-parts and add in the right-parts.
354    let mut index = 0;
355    right_parts
356        .map_each_in_place(|right| {
357            let left = values[index];
358            index += 1;
359            let left = <T as ALPRDFloat>::from_u16(left);
360            T::from_bits((left << (right_bit_width as usize)) | right)
361        })
362        .freeze()
363}
364
365/// Find the best "cut point" for a set of floating point values such that we can
366/// cast them all to the relevant value instead.
367fn find_best_dictionary<T: ALPRDFloat>(samples: &[T]) -> ALPRDDictionary {
368    let mut best_est_size = f64::MAX;
369    let mut best_dict = ALPRDDictionary::default();
370
371    for p in 1..=16 {
372        let candidate_right_bw = (T::BITS - p) as u8;
373        let (dictionary, exception_count) =
374            build_left_parts_dictionary::<T>(samples, candidate_right_bw, MAX_DICT_SIZE);
375        let estimated_size = estimate_compression_size(
376            dictionary.right_bit_width,
377            dictionary.left_bit_width,
378            exception_count,
379            samples.len(),
380        );
381        if estimated_size < best_est_size {
382            best_est_size = estimated_size;
383            best_dict = dictionary;
384        }
385    }
386
387    best_dict
388}
389
390/// Build dictionary of the leftmost bits.
391fn build_left_parts_dictionary<T: ALPRDFloat>(
392    samples: &[T],
393    right_bw: u8,
394    max_dict_size: u8,
395) -> (ALPRDDictionary, usize) {
396    assert!(
397        right_bw >= (T::BITS - CUT_LIMIT) as _,
398        "left-parts must be <= 16 bits"
399    );
400
401    // Count the number of occurrences of each left bit pattern
402    let mut counts = HashMap::new();
403    samples
404        .iter()
405        .copied()
406        .map(|v| <T as ALPRDFloat>::to_u16(T::to_bits(v).shr(right_bw as _)))
407        .for_each(|item| *counts.entry(item).or_default() += 1);
408
409    // Sorted counts: sort by negative count so that heavy hitters sort first.
410    let mut sorted_bit_counts: Vec<(u16, usize)> = counts.into_iter().collect_vec();
411    sorted_bit_counts.sort_by_key(|(_, count)| count.wrapping_neg());
412
413    // Assign the most-frequently occurring left-bits as dictionary codes, up to `dict_size`...
414    let mut dictionary = HashMap::with_capacity_and_hasher(max_dict_size as _, FxBuildHasher);
415    let mut code = 0u16;
416    while code < (max_dict_size as _) && (code as usize) < sorted_bit_counts.len() {
417        let (bits, _) = sorted_bit_counts[code as usize];
418        dictionary.insert(bits, code);
419        code += 1;
420    }
421
422    // ...and the rest are exceptions.
423    let exception_count: usize = sorted_bit_counts
424        .iter()
425        .skip(code as _)
426        .map(|(_, count)| *count)
427        .sum();
428
429    // Left bit-width is determined based on the actual dictionary size.
430    let max_code = dictionary.len() - 1;
431    let left_bw = bit_width!(max_code) as u8;
432
433    (
434        ALPRDDictionary {
435            dictionary,
436            right_bit_width: right_bw,
437            left_bit_width: left_bw,
438        },
439        exception_count,
440    )
441}
442
/// Estimate the bits-per-value when using these compression settings.
///
/// The estimate has two components:
///
/// * a fixed cost: the right-part width plus the left-code width, and
/// * an exception cost: a 16-bit position plus a 16-bit left part for every exception, amortized
///   over the `sample_n` samples.
///
/// With an empty sample there is nothing to amortize over, so only the fixed widths are returned.
/// This avoids producing `0.0 / 0.0 = NaN`.
fn estimate_compression_size(
    right_bw: u8,
    left_bw: u8,
    exception_count: usize,
    sample_n: usize,
) -> f64 {
    const EXC_POSITION_SIZE: usize = 16; // two bytes for exception position.
    const EXC_SIZE: usize = 16; // two bytes for each exception (up to 16 front bits).

    let fixed_bits = f64::from(right_bw) + f64::from(left_bw);
    if sample_n == 0 {
        return fixed_bits;
    }

    let exceptions_size = exception_count * (EXC_POSITION_SIZE + EXC_SIZE);
    fixed_bits + (exceptions_size as f64) / (sample_n as f64)
}
456
/// The ALP-RD dictionary, encoding the "left parts" and their dictionary encoding.
///
/// One of these is built per candidate split point by `build_left_parts_dictionary`.
#[derive(Debug, Default)]
struct ALPRDDictionary {
    /// Maps each left-part bit pattern to its 16-bit dictionary code (codes are assigned from 0,
    /// most frequent pattern first).
    dictionary: HashMap<u16, u16, FxBuildHasher>,
    /// The (compressed) left bit width: the bit width needed to bit-pack the largest dictionary
    /// code.
    left_bit_width: u8,
    /// The right bit width: the number of low bits of each value kept as the right part and
    /// bit-packed as-is.
    right_bit_width: u8,
}