vortex_sampling_compressor/
sampling_compressor.rs

1use core::fmt::Formatter;
2use std::fmt::Display;
3
4use rand::rngs::StdRng;
5use rand::SeedableRng as _;
6use vortex_array::aliases::hash_set::HashSet;
7use vortex_array::array::{ChunkedArray, ConstantEncoding};
8use vortex_array::compress::{
9    check_dtype_unchanged, check_statistics_unchanged, check_validity_unchanged,
10    CompressionStrategy,
11};
12use vortex_array::compute::slice;
13use vortex_array::patches::Patches;
14use vortex_array::validity::Validity;
15use vortex_array::{Array, Encoding, EncodingId, IntoCanonical};
16use vortex_error::{VortexExpect as _, VortexResult};
17
18use super::compressors::chunked::DEFAULT_CHUNKED_COMPRESSOR;
19use super::compressors::struct_::StructCompressor;
20use super::{CompressConfig, Objective, DEFAULT_COMPRESSORS};
21use crate::compressors::constant::ConstantCompressor;
22use crate::compressors::{CompressedArray, CompressionTree, CompressorRef, EncodingCompressor};
23use crate::downscale::downscale_integer_array;
24use crate::sampling::stratified_slices;
25
/// A compression strategy that chooses the best encoding for an array by
/// trialling candidate compressors on stratified samples of the data.
#[derive(Debug, Clone)]
pub struct SamplingCompressor<'a> {
    /// The set of candidate compressors considered during compression.
    compressors: HashSet<CompressorRef<'a>>,
    /// Tuning options (sample size/count, max cost, objective, rng seed).
    options: CompressConfig,

    /// Dot-joined path of named compression steps, used only for logging.
    path: Vec<String>,
    /// Cumulative cost of compressors applied so far; candidates whose added
    /// cost would exceed `options.max_cost` are skipped.
    depth: u8,
    /// A set of encodings disabled for this ctx.
    disabled_compressors: HashSet<CompressorRef<'a>>,
}
36
37impl Display for SamplingCompressor<'_> {
38    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39        write!(f, "[{}|{}]", self.depth, self.path.join("."))
40    }
41}
42
43impl CompressionStrategy for SamplingCompressor<'_> {
44    #[allow(clippy::same_name_method)]
45    fn compress(&self, array: &Array) -> VortexResult<Array> {
46        Self::compress(self, array, None).map(CompressedArray::into_array)
47    }
48
49    fn used_encodings(&self) -> HashSet<EncodingId> {
50        self.compressors
51            .iter()
52            .flat_map(|c| c.used_encodings())
53            .collect()
54    }
55}
56
57impl Default for SamplingCompressor<'_> {
58    fn default() -> Self {
59        Self::new(HashSet::from_iter(DEFAULT_COMPRESSORS))
60    }
61}
62
impl<'a> SamplingCompressor<'a> {
    /// Creates a compressor with the given candidates and default options.
    pub fn new(compressors: impl Into<HashSet<CompressorRef<'a>>>) -> Self {
        Self::new_with_options(compressors, Default::default())
    }

    /// Creates a compressor with the given candidates and explicit options.
    /// Path, depth, and the disabled set all start empty/zero.
    pub fn new_with_options(
        compressors: impl Into<HashSet<CompressorRef<'a>>>,
        options: CompressConfig,
    ) -> Self {
        Self {
            compressors: compressors.into(),
            options,
            path: Vec::new(),
            depth: 0,
            disabled_compressors: HashSet::new(),
        }
    }

    /// Returns a clone of this ctx with `name` appended to the logging path.
    pub fn named(&self, name: &str) -> Self {
        let mut cloned = self.clone();
        cloned.path.push(name.into());
        cloned
    }

    // Returns a new ctx used for compressing an auxiliary array.
    // In practice, this means resetting any disabled encodings back to the original config.
    pub fn auxiliary(&self, name: &str) -> Self {
        let mut cloned = self.clone();
        cloned.path.push(name.into());
        cloned.disabled_compressors = HashSet::new();
        cloned
    }

    /// Returns a clone of this ctx with its depth increased by `compression`'s cost,
    /// used when recursing into a child compressor.
    pub fn for_compressor(&self, compression: &dyn EncodingCompressor) -> Self {
        let mut cloned = self.clone();
        cloned.depth += compression.cost();
        cloned
    }

    /// The configured compression options.
    #[inline]
    pub fn options(&self) -> &CompressConfig {
        &self.options
    }

    /// Returns a clone of this ctx with `compressor` added to the disabled set.
    pub fn excluding(&self, compressor: CompressorRef<'a>) -> Self {
        let mut cloned = self.clone();
        cloned.disabled_compressors.insert(compressor);
        cloned
    }

    /// Returns a clone of this ctx with `compressor` added to the candidate set.
    pub fn including(&self, compressor: CompressorRef<'a>) -> Self {
        let mut cloned = self.clone();
        cloned.compressors.insert(compressor);
        cloned
    }

    /// Returns a clone of this ctx whose candidate set is exactly `compressors`.
    pub fn including_only(&self, compressors: &[CompressorRef<'a>]) -> Self {
        let mut cloned = self.clone();
        cloned.compressors.clear();
        cloned.compressors.extend(compressors);
        cloned
    }

    /// True iff `compressor` is a candidate and has not been disabled for this ctx.
    pub fn is_enabled(&self, compressor: CompressorRef<'a>) -> bool {
        self.compressors.contains(compressor) && !self.disabled_compressors.contains(compressor)
    }

    /// Compresses `arr`, preferring the compression tree `like` (typically taken
    /// from a previously-compressed similar array) before falling back to
    /// sampled search. Empty arrays are returned uncompressed.
    #[allow(clippy::same_name_method)]
    pub fn compress(
        &self,
        arr: &Array,
        like: Option<&CompressionTree<'a>>,
    ) -> VortexResult<CompressedArray<'a>> {
        if arr.is_empty() {
            return Ok(CompressedArray::uncompressed(arr.clone()));
        }

        // Attempt to compress using the "like" array, otherwise fall back to sampled compression
        if let Some(l) = like {
            if let Some(compressed) = l.compress(arr, self) {
                let compressed = compressed?;

                // Debug-assert the compressor didn't corrupt the array's contract.
                check_validity_unchanged(arr, compressed.as_ref());
                check_dtype_unchanged(arr, compressed.as_ref());
                check_statistics_unchanged(arr, compressed.as_ref());
                return Ok(compressed);
            } else {
                log::debug!("{} cannot compress {} like {}", self, arr, l);
            }
        }

        // Otherwise, attempt to compress the array
        let compressed = self.compress_array(arr)?;

        check_validity_unchanged(arr, compressed.as_ref());
        check_dtype_unchanged(arr, compressed.as_ref());
        check_statistics_unchanged(arr, compressed.as_ref());
        Ok(compressed)
    }

    /// Compresses the backing array of an `Validity::Array`; other validity
    /// variants are returned unchanged.
    pub fn compress_validity(&self, validity: Validity) -> VortexResult<Validity> {
        match validity {
            Validity::Array(a) => Ok(Validity::Array(self.compress(&a, None)?.into_array())),
            a => Ok(a),
        }
    }

    /// Compresses a `Patches` value: indices are downscaled to the narrowest
    /// integer type first, then indices and values are compressed independently.
    pub fn compress_patches(&self, patches: Patches) -> VortexResult<Patches> {
        Ok(Patches::new(
            patches.array_len(),
            self.compress(&downscale_integer_array(patches.indices().clone())?, None)?
                .into_array(),
            self.compress(patches.values(), None)?.into_array(),
        ))
    }

    /// Core search: picks candidate compressors, and either compresses the whole
    /// array directly (small arrays) or evaluates candidates on a stratified
    /// sample and applies the winner to the full array.
    pub(crate) fn compress_array(&self, array: &Array) -> VortexResult<CompressedArray<'a>> {
        // Fixed seed makes sampling (and thus compression choices) deterministic.
        let mut rng = StdRng::seed_from_u64(self.options.rng_seed);

        if array.is_encoding(ConstantEncoding::ID) {
            // Not much better we can do than constant!
            return Ok(CompressedArray::uncompressed(array.clone()));
        }

        // Chunked and struct arrays are decomposed structurally, per-child,
        // rather than sampled as a whole.
        if let Some(cc) = DEFAULT_CHUNKED_COMPRESSOR.can_compress(array) {
            return cc.compress(array, None, self.clone());
        }

        if let Some(cc) = StructCompressor.can_compress(array) {
            return cc.compress(array, None, self.clone());
        }

        // short-circuit because seriously nothing beats constant
        if self.is_enabled(&ConstantCompressor) && ConstantCompressor.can_compress(array).is_some()
        {
            return ConstantCompressor.compress(array, None, self.clone());
        }

        // Split applicable compressors into affordable candidates vs. those whose
        // cost would push us past max_cost.
        let (mut candidates, too_deep) = self
            .compressors
            .iter()
            .filter(|&encoding| !self.disabled_compressors.contains(encoding))
            .filter(|&encoding| encoding.can_compress(array).is_some())
            .partition::<Vec<&dyn EncodingCompressor>, _>(|&encoding| {
                self.depth + encoding.cost() <= self.options.max_cost
            });

        if !too_deep.is_empty() {
            log::debug!(
                "{} skipping encodings due to depth/cost: {}",
                self,
                too_deep
                    .iter()
                    .map(|x| x.id())
                    .collect::<Vec<_>>()
                    .join(", ")
            );
        }

        log::debug!("{} candidates for {}: {:?}", self, array, candidates);

        if candidates.is_empty() {
            log::debug!(
                "{} no compressors for array with dtype: {} and encoding: {}",
                self,
                array.dtype(),
                array.encoding(),
            );
            return Ok(CompressedArray::uncompressed(array.clone()));
        }

        // We prefer all other candidates to the array's own encoding.
        // This is because we assume that the array's own encoding is the least efficient, but useful
        // to destructure an array in the final stages of compression. e.g. VarBin would be DictEncoded
        // but then the dictionary itself remains a VarBin array. DictEncoding excludes itself from the
        // dictionary, but we still have a large offsets array that should be compressed.
        // TODO(ngates): we actually probably want some way to prefer dict encoding over other varbin
        //  encodings, e.g. FSST.
        if candidates.len() > 1 {
            candidates.retain(|&compression| compression.id() != array.encoding().as_ref());
        }

        if array.len() <= (self.options.sample_size as usize * self.options.sample_count as usize) {
            // We're either already within a sample, or we're operating over a sufficiently small array.
            return find_best_compression(candidates, array, self);
        }

        // Take a sample of the array, then ask codecs for their best compression estimate.
        let sample = ChunkedArray::try_new(
            stratified_slices(
                array.len(),
                self.options.sample_size,
                self.options.sample_count,
                &mut rng,
            )
            .into_iter()
            .map(|(start, stop)| slice(array, start, stop))
            .collect::<VortexResult<Vec<Array>>>()?,
            array.dtype().clone(),
        )?
        .into_canonical()?
        .into();

        // Re-run the winner over the full array; fall back to uncompressed if
        // no candidate beat the objective on the sample.
        let best = find_best_compression(candidates, &sample, self)?
            .into_path()
            .map(|best_compressor| {
                log::debug!(
                    "{} Compressing array {} with {}",
                    self,
                    array,
                    best_compressor
                );
                best_compressor.compress_unchecked(array, self)
            })
            .transpose()?;

        Ok(best.unwrap_or_else(|| CompressedArray::uncompressed(array.clone())))
    }
}
282
283pub(crate) fn find_best_compression<'a>(
284    candidates: Vec<&'a dyn EncodingCompressor>,
285    sample: &Array,
286    ctx: &SamplingCompressor<'a>,
287) -> VortexResult<CompressedArray<'a>> {
288    let mut best = None;
289    let mut best_objective = ctx.options().objective.starting_value();
290    let mut best_objective_ratio = 1.0;
291    // for logging
292    let mut best_compression_ratio = 1.0;
293    let mut best_compression_ratio_sample = None;
294
295    for compression in candidates {
296        log::debug!(
297            "{} trying candidate {} for {}",
298            ctx,
299            compression.id(),
300            sample
301        );
302        if compression.can_compress(sample).is_none() {
303            continue;
304        }
305        let compressed_sample =
306            compression.compress(sample, None, ctx.for_compressor(compression))?;
307
308        let ratio = (compressed_sample.nbytes() as f64) / (sample.nbytes() as f64);
309        let objective = Objective::evaluate(&compressed_sample, sample.nbytes(), ctx.options());
310
311        // track the compression ratio, just for logging
312        if ratio < best_compression_ratio {
313            best_compression_ratio = ratio;
314
315            // if we find one with a better compression ratio but worse objective value, save it
316            // for debug logging later.
317            if ratio < best_objective_ratio && objective >= best_objective {
318                best_compression_ratio_sample = Some(compressed_sample.clone());
319            }
320        }
321
322        // don't consider anything that compresses to be *larger* than uncompressed
323        if objective < best_objective && ratio < 1.0 {
324            best_objective = objective;
325            best_objective_ratio = ratio;
326            best = Some(compressed_sample);
327        }
328
329        log::debug!(
330            "{} with {}: ratio ({}), objective fn value ({}); best so far: ratio ({}), objective fn value ({})",
331            ctx,
332            compression.id(),
333            ratio,
334            objective,
335            best_compression_ratio,
336            best_objective
337        );
338    }
339
340    let best = best.unwrap_or_else(|| CompressedArray::uncompressed(sample.clone()));
341    if best_compression_ratio < best_objective_ratio && best_compression_ratio_sample.is_some() {
342        let best_ratio_sample =
343            best_compression_ratio_sample.vortex_expect("already checked that this Option is Some");
344        log::debug!(
345            "{} best objective fn value ({}) has ratio {} from {}",
346            ctx,
347            best_objective,
348            best_compression_ratio,
349            best.array().tree_display()
350        );
351        log::debug!(
352            "{} best ratio ({}) has objective fn value {} from {}",
353            ctx,
354            best_compression_ratio,
355            best_objective,
356            best_ratio_sample.array().tree_display()
357        );
358    }
359
360    log::debug!(
361        "{} best compression ({} bytes, {} objective fn value, {} compression ratio",
362        ctx,
363        best.nbytes(),
364        best_objective,
365        best_compression_ratio,
366    );
367
368    Ok(best)
369}
370
#[cfg(test)]
mod tests {
    use vortex_alp::ALPRDEncoding;
    use vortex_array::array::PrimitiveArray;
    use vortex_array::{Encoding, IntoArray};

    use crate::SamplingCompressor;

    /// The default compressor should pick ALP-RD for small-magnitude floats.
    #[test]
    fn test_default() {
        let values = (0..4096).map(|x| (x as f64) / 1234567890.0f64);
        let array = PrimitiveArray::from_iter(values).into_array();

        let result = SamplingCompressor::default()
            .compress(&array, None)
            .unwrap();
        assert_eq!(result.into_array().encoding(), ALPRDEncoding::ID);
    }
}