1mod dictionary;
2mod stats;
3
4use vortex_alp::{ALPArray, ALPEncoding, RDEncoder};
5use vortex_array::arrays::{ConstantArray, PrimitiveArray};
6use vortex_array::variants::PrimitiveArrayTrait;
7use vortex_array::vtable::EncodingVTable;
8use vortex_array::{Array, ArrayExt as _, ArrayRef, ArrayStatistics, ToCanonical};
9use vortex_dict::DictArray;
10use vortex_dtype::PType;
11use vortex_error::{VortexExpect, VortexResult, vortex_panic};
12use vortex_runend::RunEndArray;
13use vortex_runend::compress::runend_encode;
14
15use self::stats::FloatStats;
16use crate::float::dictionary::dictionary_encode;
17use crate::integer::{IntCompressor, IntegerStats};
18use crate::patches::compress_patches;
19use crate::{
20 Compressor, CompressorStats, GenerateStatsOptions, Scheme,
21 estimate_compression_ratio_with_sampling, integer,
22};
23
/// Minimum average run length an array must have before run-end encoding is considered.
///
/// `RunEndScheme::expected_compression_ratio` reports a ratio of `0.0` whenever
/// `FloatStats::average_run_length` is below this value.
const RUN_END_THRESHOLD: u32 = 3;
26
27pub trait FloatScheme: Scheme<StatsType = FloatStats, CodeType = FloatCode> {}
28
29impl<T> FloatScheme for T where T: Scheme<StatsType = FloatStats, CodeType = FloatCode> {}
30
/// Compressor for floating-point [`PrimitiveArray`]s.
///
/// The candidate schemes it chooses from are listed in its [`Compressor`] implementation.
pub struct FloatCompressor;
32
33impl Compressor for FloatCompressor {
34 type ArrayType = PrimitiveArray;
35 type SchemeType = dyn FloatScheme;
36 type StatsType = FloatStats;
37
38 fn schemes() -> &'static [&'static Self::SchemeType] {
39 &[
40 &UncompressedScheme,
41 &ConstantScheme,
42 &ALPScheme,
43 &ALPRDScheme,
44 &DictScheme,
45 ]
46 }
47
48 fn default_scheme() -> &'static Self::SchemeType {
49 &UncompressedScheme
50 }
51
52 fn dict_scheme_code() -> FloatCode {
53 DICT_SCHEME
54 }
55}
56
/// Code returned by `UncompressedScheme::code`.
const UNCOMPRESSED_SCHEME: FloatCode = FloatCode(0);
/// Code returned by `ConstantScheme::code`.
const CONSTANT_SCHEME: FloatCode = FloatCode(1);
/// Code returned by `ALPScheme::code`.
const ALP_SCHEME: FloatCode = FloatCode(2);
/// Code returned by `ALPRDScheme::code`.
const ALPRD_SCHEME: FloatCode = FloatCode(3);
/// Code returned by `DictScheme::code`; also used to exclude dictionary encoding when
/// compressing dictionary values.
const DICT_SCHEME: FloatCode = FloatCode(4);
/// Code returned by `RunEndScheme::code`.
const RUNEND_SCHEME: FloatCode = FloatCode(5);
63
/// Scheme that returns the source array unchanged.
#[derive(Debug, Copy, Clone)]
struct UncompressedScheme;

/// Scheme that replaces the array with a `ConstantArray` holding its single value.
#[derive(Debug, Copy, Clone)]
struct ConstantScheme;

/// Scheme that ALP-encodes the floats and compresses the encoded integers with `IntCompressor`.
#[derive(Debug, Copy, Clone)]
struct ALPScheme;

/// Scheme that encodes `f32`/`f64` arrays with `RDEncoder`.
#[derive(Debug, Copy, Clone)]
struct ALPRDScheme;

/// Scheme that dictionary-encodes the array and compresses the codes and values separately.
#[derive(Debug, Copy, Clone)]
struct DictScheme;

/// Scheme that run-end encodes the array and compresses the run ends with `IntCompressor`.
#[derive(Debug, Copy, Clone)]
struct RunEndScheme;
81
82impl Scheme for UncompressedScheme {
83 type StatsType = FloatStats;
84 type CodeType = FloatCode;
85
86 fn code(&self) -> FloatCode {
87 UNCOMPRESSED_SCHEME
88 }
89
90 fn expected_compression_ratio(
91 &self,
92 _stats: &Self::StatsType,
93 _is_sample: bool,
94 _allowed_cascading: usize,
95 _excludes: &[FloatCode],
96 ) -> VortexResult<f64> {
97 Ok(1.0)
98 }
99
100 fn compress(
101 &self,
102 stats: &Self::StatsType,
103 _is_sample: bool,
104 _allowed_cascading: usize,
105 _excludes: &[FloatCode],
106 ) -> VortexResult<ArrayRef> {
107 Ok(stats.source().to_array())
108 }
109}
110
111impl Scheme for ConstantScheme {
112 type StatsType = FloatStats;
113 type CodeType = FloatCode;
114
115 fn code(&self) -> FloatCode {
116 CONSTANT_SCHEME
117 }
118
119 fn expected_compression_ratio(
120 &self,
121 stats: &Self::StatsType,
122 is_sample: bool,
123 _allowed_cascading: usize,
124 _excludes: &[FloatCode],
125 ) -> VortexResult<f64> {
126 if is_sample {
128 return Ok(0.0);
129 }
130
131 if stats.distinct_values_count > 1 {
133 return Ok(0.0);
134 }
135
136 if stats.null_count > 0 && stats.value_count > 0 {
138 return Ok(0.0);
139 }
140
141 Ok(stats.value_count as f64)
142 }
143
144 fn compress(
145 &self,
146 stats: &Self::StatsType,
147 _is_sample: bool,
148 _allowed_cascading: usize,
149 _excludes: &[FloatCode],
150 ) -> VortexResult<ArrayRef> {
151 let scalar = stats
152 .source()
153 .as_constant()
154 .vortex_expect("must be constant");
155
156 Ok(ConstantArray::new(scalar, stats.source().len()).into_array())
157 }
158}
159
/// Identifier of a float compression scheme, wrapping a `u8`.
///
/// Each scheme in this file returns its own code from `Scheme::code`, and exclusion lists are
/// expressed as slices of these codes.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct FloatCode(u8);
162
163impl Scheme for ALPScheme {
164 type StatsType = FloatStats;
165 type CodeType = FloatCode;
166
167 fn code(&self) -> FloatCode {
168 ALP_SCHEME
169 }
170
171 fn expected_compression_ratio(
172 &self,
173 stats: &Self::StatsType,
174 is_sample: bool,
175 allowed_cascading: usize,
176 excludes: &[FloatCode],
177 ) -> VortexResult<f64> {
178 if stats.source().ptype() == PType::F16 {
180 return Ok(0.0);
181 }
182
183 if allowed_cascading == 0 {
184 return Ok(0.0);
187 }
188
189 estimate_compression_ratio_with_sampling(
190 self,
191 stats,
192 is_sample,
193 allowed_cascading,
194 excludes,
195 )
196 }
197
198 fn compress(
199 &self,
200 stats: &FloatStats,
201 is_sample: bool,
202 allowed_cascading: usize,
203 excludes: &[FloatCode],
204 ) -> VortexResult<ArrayRef> {
205 let alp_encoded = ALPEncoding
206 .encode(&stats.source().to_canonical()?, None)?
207 .vortex_expect("Input is a supported floating point array");
208 let alp = alp_encoded.as_::<ALPArray>();
209 let alp_ints = alp.encoded().to_primitive()?;
210
211 let mut int_excludes = Vec::new();
215 if excludes.contains(&DICT_SCHEME) {
216 int_excludes.push(integer::DictScheme.code());
217 }
218 if excludes.contains(&RUNEND_SCHEME) {
219 int_excludes.push(integer::RunEndScheme.code());
220 }
221
222 let compressed_alp_ints =
223 IntCompressor::compress(&alp_ints, is_sample, allowed_cascading - 1, &int_excludes)?;
224
225 let patches = alp.patches().map(compress_patches).transpose()?;
226
227 Ok(ALPArray::try_new(compressed_alp_ints, alp.exponents(), patches)?.into_array())
228 }
229}
230
231impl Scheme for ALPRDScheme {
232 type StatsType = FloatStats;
233 type CodeType = FloatCode;
234
235 fn code(&self) -> FloatCode {
236 ALPRD_SCHEME
237 }
238
239 fn expected_compression_ratio(
240 &self,
241 stats: &Self::StatsType,
242 is_sample: bool,
243 allowed_cascading: usize,
244 excludes: &[FloatCode],
245 ) -> VortexResult<f64> {
246 if stats.source().ptype() == PType::F16 {
247 return Ok(0.0);
248 }
249
250 estimate_compression_ratio_with_sampling(
251 self,
252 stats,
253 is_sample,
254 allowed_cascading,
255 excludes,
256 )
257 }
258
259 fn compress(
260 &self,
261 stats: &Self::StatsType,
262 _is_sample: bool,
263 _allowed_cascading: usize,
264 _excludes: &[FloatCode],
265 ) -> VortexResult<ArrayRef> {
266 let encoder = match stats.source().ptype() {
267 PType::F32 => RDEncoder::new(stats.source().as_slice::<f32>()),
268 PType::F64 => RDEncoder::new(stats.source().as_slice::<f64>()),
269 ptype => vortex_panic!("cannot ALPRD compress ptype {ptype}"),
270 };
271
272 let mut alp_rd = encoder.encode(stats.source());
273
274 let patches = alp_rd
275 .left_parts_patches()
276 .map(compress_patches)
277 .transpose()?;
278 alp_rd.replace_left_parts_patches(patches);
279
280 Ok(alp_rd.into_array())
281 }
282}
283
284impl Scheme for DictScheme {
285 type StatsType = FloatStats;
286 type CodeType = FloatCode;
287
288 fn code(&self) -> FloatCode {
289 DICT_SCHEME
290 }
291
292 fn expected_compression_ratio(
293 &self,
294 stats: &Self::StatsType,
295 is_sample: bool,
296 allowed_cascading: usize,
297 excludes: &[FloatCode],
298 ) -> VortexResult<f64> {
299 if stats.value_count == 0 {
300 return Ok(0.0);
301 }
302
303 if stats.distinct_values_count > stats.value_count / 2 {
305 return Ok(0.0);
306 }
307
308 estimate_compression_ratio_with_sampling(
310 self,
311 stats,
312 is_sample,
313 allowed_cascading,
314 excludes,
315 )
316 }
317
318 fn compress(
319 &self,
320 stats: &Self::StatsType,
321 is_sample: bool,
322 allowed_cascading: usize,
323 _excludes: &[FloatCode],
324 ) -> VortexResult<ArrayRef> {
325 let dict_array = dictionary_encode(stats)?;
326
327 let codes_stats = IntegerStats::generate_opts(
329 &dict_array.codes().to_primitive()?,
330 GenerateStatsOptions {
331 count_distinct_values: false,
332 },
333 );
334 let codes_scheme = IntCompressor::choose_scheme(
335 &codes_stats,
336 is_sample,
337 allowed_cascading - 1,
338 &[integer::DictScheme.code()],
339 )?;
340 let compressed_codes = codes_scheme.compress(
341 &codes_stats,
342 is_sample,
343 allowed_cascading - 1,
344 &[integer::DictScheme.code()],
345 )?;
346
347 let compressed_values = FloatCompressor::compress(
348 &dict_array.values().to_primitive()?,
349 is_sample,
350 allowed_cascading - 1,
351 &[DICT_SCHEME],
352 )?;
353
354 Ok(DictArray::try_new(compressed_codes, compressed_values)?.into_array())
355 }
356}
357
358impl Scheme for RunEndScheme {
359 type StatsType = FloatStats;
360 type CodeType = FloatCode;
361
362 fn code(&self) -> FloatCode {
363 RUNEND_SCHEME
364 }
365
366 fn expected_compression_ratio(
367 &self,
368 stats: &Self::StatsType,
369 is_sample: bool,
370 allowed_cascading: usize,
371 excludes: &[FloatCode],
372 ) -> VortexResult<f64> {
373 if stats.average_run_length < RUN_END_THRESHOLD {
374 return Ok(0.0);
375 }
376
377 estimate_compression_ratio_with_sampling(
378 self,
379 stats,
380 is_sample,
381 allowed_cascading,
382 excludes,
383 )
384 }
385
386 fn compress(
387 &self,
388 stats: &FloatStats,
389 is_sample: bool,
390 allowed_cascading: usize,
391 _excludes: &[FloatCode],
392 ) -> VortexResult<ArrayRef> {
393 let (ends, values) = runend_encode(stats.source())?;
394 let compressed_ends = IntCompressor::compress(
396 &ends,
397 is_sample,
398 allowed_cascading - 1,
399 &[
400 integer::RunEndScheme.code(),
401 integer::DictScheme.code(),
402 integer::SparseScheme.code(),
403 ],
404 )?;
405
406 Ok(RunEndArray::try_new(compressed_ends, values)?.into_array())
407 }
408}
409
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::{Buffer, buffer_mut};

    use crate::float::FloatCompressor;
    use crate::{Compressor, MAX_CASCADE};

    /// Compressing an empty array must succeed and yield an empty array.
    #[test]
    fn test_empty() {
        let result = FloatCompressor::compress(
            &PrimitiveArray::new(Buffer::<f32>::empty(), Validity::NonNullable),
            false,
            3,
            &[],
        )
        .unwrap();

        assert!(result.is_empty());
    }

    /// Compressing 1024 floats drawn from 50 distinct values must succeed and keep the length.
    #[test]
    fn test_compress() {
        let mut values = buffer_mut![1.0f32; 1024];
        for i in 0..1024 {
            // Cycle through 0.0..=49.0 so the input has low cardinality.
            values[i] = (i % 50) as f32;
        }

        let floats = values.into_array().to_primitive().unwrap();
        let compressed = FloatCompressor::compress(&floats, false, MAX_CASCADE, &[]).unwrap();

        // Whatever scheme is chosen, the logical length must not change.
        assert_eq!(compressed.len(), floats.len());
        println!("compressed: {}", compressed.tree_display());
    }
}