vortex_compressor/builtins/dict/
float.rs1use vortex_array::ArrayRef;
10use vortex_array::ArrayView;
11use vortex_array::Canonical;
12use vortex_array::ExecutionCtx;
13use vortex_array::IntoArray;
14use vortex_array::arrays::DictArray;
15use vortex_array::arrays::Primitive;
16use vortex_array::arrays::PrimitiveArray;
17use vortex_array::arrays::dict::DictArrayExt;
18use vortex_array::arrays::dict::DictArraySlotsExt;
19use vortex_array::arrays::primitive::PrimitiveArrayExt;
20use vortex_array::dtype::half::f16;
21use vortex_array::validity::Validity;
22use vortex_buffer::Buffer;
23use vortex_error::VortexExpect;
24use vortex_error::VortexResult;
25
26use crate::CascadingCompressor;
27use crate::builtins::IntDictScheme;
28use crate::ctx::CompressorContext;
29use crate::estimate::CompressionEstimate;
30use crate::estimate::DeferredEstimate;
31use crate::estimate::EstimateVerdict;
32use crate::scheme::ChildSelection;
33use crate::scheme::DescendantExclusion;
34use crate::scheme::Scheme;
35use crate::scheme::SchemeExt;
36use crate::stats::ArrayAndStats;
37use crate::stats::FloatErasedStats;
38use crate::stats::FloatStats;
39use crate::stats::GenerateStatsOptions;
40
/// Compression scheme that dictionary-encodes floating-point arrays.
///
/// This is a stateless unit type; its behavior is defined by its [`Scheme`] implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FloatDictScheme;
44
45impl Scheme for FloatDictScheme {
46 fn scheme_name(&self) -> &'static str {
47 "vortex.float.dict"
48 }
49
50 fn matches(&self, canonical: &Canonical) -> bool {
51 canonical.dtype().is_float()
52 }
53
54 fn stats_options(&self) -> GenerateStatsOptions {
55 GenerateStatsOptions {
56 count_distinct_values: true,
57 }
58 }
59
60 fn num_children(&self) -> usize {
62 2
63 }
64
65 fn descendant_exclusions(&self) -> Vec<DescendantExclusion> {
74 vec![
75 DescendantExclusion {
76 excluded: IntDictScheme.id(),
77 children: ChildSelection::One(1),
78 },
79 DescendantExclusion {
80 excluded: IntDictScheme.id(),
81 children: ChildSelection::One(0),
82 },
83 ]
84 }
85
86 fn expected_compression_ratio(
87 &self,
88 data: &ArrayAndStats,
89 _compress_ctx: CompressorContext,
90 exec_ctx: &mut ExecutionCtx,
91 ) -> CompressionEstimate {
92 let stats = data.float_stats(exec_ctx);
93
94 if stats.value_count() == 0 {
95 return CompressionEstimate::Verdict(EstimateVerdict::Skip);
96 }
97
98 let distinct_values_count = stats.distinct_count().vortex_expect(
99 "this must be present since `DictScheme` declared that we need distinct values",
100 );
101
102 if distinct_values_count > stats.value_count() / 2 {
104 return CompressionEstimate::Verdict(EstimateVerdict::Skip);
105 }
106
107 CompressionEstimate::Deferred(DeferredEstimate::Sample)
109 }
110
111 fn compress(
112 &self,
113 compressor: &CascadingCompressor,
114 data: &ArrayAndStats,
115 compress_ctx: CompressorContext,
116 exec_ctx: &mut ExecutionCtx,
117 ) -> VortexResult<ArrayRef> {
118 let stats = data.float_stats(exec_ctx);
119 let dict = dictionary_encode(data.array_as_primitive(), &stats)?;
120
121 let has_all_values_referenced = dict.has_all_values_referenced();
122
123 let compressed_values =
125 compressor.compress_child(dict.values(), &compress_ctx, self.id(), 0, exec_ctx)?;
126
127 let narrowed_codes = dict
129 .codes()
130 .clone()
131 .execute::<PrimitiveArray>(exec_ctx)?
132 .narrow(exec_ctx)?
133 .into_array();
134 let compressed_codes =
135 compressor.compress_child(&narrowed_codes, &compress_ctx, self.id(), 1, exec_ctx)?;
136
137 unsafe {
139 Ok(
140 DictArray::new_unchecked(compressed_codes, compressed_values)
141 .set_all_values_referenced(has_all_values_referenced)
142 .into_array(),
143 )
144 }
145 }
146}
147
/// Builds a [`DictArray`] for one concrete float type from pre-computed distinct values.
///
/// Arguments:
/// - `$source_array`: the primitive array (view) being encoded.
/// - `$stats`: accepted for call-site compatibility; not used by the expansion.
/// - `$typed`: the per-type stats exposing `distinct()`.
/// - `$typ`: the float element type (`f16`, `f32`, or `f64`).
///
/// The expansion evaluates to a `VortexResult<DictArray>` and uses `?`, so it must be
/// invoked inside a function returning a compatible `Result`. It panics if `$typed` carries
/// no distinct-value information.
macro_rules! typed_encode {
    ($source_array:ident, $stats:ident, $typed:ident, $typ:ty) => {{
        let distinct = $typed.distinct().vortex_expect(
            "this must be present since `DictScheme` declared that we need distinct values",
        );

        // Query the source validity once: the codes inherit it unchanged, while the
        // dictionary values only need to agree on nullability.
        let codes_validity = $source_array.validity()?;
        let values_validity = match &codes_validity {
            Validity::NonNullable => Validity::NonNullable,
            _ => Validity::AllValid,
        };

        let values: Buffer<$typ> = distinct.distinct_values().iter().map(|x| x.0).collect();

        // `max_code` is the number of distinct values, i.e. one more than the largest code
        // the encoder can emit. It selects the narrowest code width used here.
        let max_code = values.len();
        let codes = if max_code <= u8::MAX as usize {
            let buf = <DictEncoder as Encode<$typ, u8>>::encode(
                &values,
                $source_array.as_slice::<$typ>(),
            );
            PrimitiveArray::new(buf, codes_validity).into_array()
        } else if max_code <= u16::MAX as usize {
            let buf = <DictEncoder as Encode<$typ, u16>>::encode(
                &values,
                $source_array.as_slice::<$typ>(),
            );
            PrimitiveArray::new(buf, codes_validity).into_array()
        } else {
            let buf = <DictEncoder as Encode<$typ, u32>>::encode(
                &values,
                $source_array.as_slice::<$typ>(),
            );
            PrimitiveArray::new(buf, codes_validity).into_array()
        };

        let values = PrimitiveArray::new(values, values_validity).into_array();
        // SAFETY: `DictEncoder` assigns each distinct value its index in `values`, so codes
        // for found values are in bounds; lookups that miss fall back to the default code.
        // NOTE(review): assumes misses only occur at null slots and that every distinct
        // value recorded in the stats occurs in the source array — confirm against the
        // stats generation code.
        Ok(unsafe { DictArray::new_unchecked(codes, values).set_all_values_referenced(true) })
    }};
}
189
190pub fn dictionary_encode(
196 array: ArrayView<'_, Primitive>,
197 stats: &FloatStats,
198) -> VortexResult<DictArray> {
199 match stats.erased() {
200 FloatErasedStats::F16(typed) => typed_encode!(array, stats, typed, f16),
201 FloatErasedStats::F32(typed) => typed_encode!(array, stats, typed, f32),
202 FloatErasedStats::F64(typed) => typed_encode!(array, stats, typed, f64),
203 }
204}
205
/// Stateless marker type that carries the [`Encode`] implementations for each
/// float-type / code-type pair.
struct DictEncoder;
208
209trait Encode<T, I> {
211 fn encode(distinct: &[T], values: &[T]) -> Buffer<I>;
213}
214
215macro_rules! impl_encode {
217 ($typ:ty, $utyp:ty) => { impl_encode!($typ, $utyp, u8, u16, u32); };
218 ($typ:ty, $utyp:ty, $($ityp:ty),+) => {
219 $(
220 impl Encode<$typ, $ityp> for DictEncoder {
221 #[expect(clippy::cast_possible_truncation)]
222 fn encode(distinct: &[$typ], values: &[$typ]) -> Buffer<$ityp> {
223 let mut codes =
224 vortex_utils::aliases::hash_map::HashMap::<$utyp, $ityp>::with_capacity(
225 distinct.len(),
226 );
227 for (code, &value) in distinct.iter().enumerate() {
228 codes.insert(value.to_bits(), code as $ityp);
229 }
230
231 let mut output = vortex_buffer::BufferMut::with_capacity(values.len());
232 for value in values {
233 output.push(codes.get(&value.to_bits()).copied().unwrap_or_default());
235 }
236
237 output.freeze()
238 }
239 }
240 )*
241 };
242}
243
244impl_encode!(f16, u16);
245impl_encode!(f32, u32);
246impl_encode!(f64, u64);
247
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::dict::DictArraySlotsExt;
    use vortex_array::assert_arrays_eq;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use super::dictionary_encode;
    use crate::stats::FloatStats;
    use crate::stats::GenerateStatsOptions;

    /// Encodes a nullable `f32` array and checks the dictionary shape and the decoded result.
    #[test]
    fn test_float_dict_encode() -> VortexResult<()> {
        let mut ctx = VortexSession::empty()
            .with::<ArraySession>()
            .create_execution_ctx();

        // The fourth slot is null in both the input and the expected output.
        let nullable = || {
            Validity::Array(BoolArray::from_iter([true, true, true, false, true]).into_array())
        };

        let array = PrimitiveArray::new(buffer![1f32, 2f32, 2f32, 0f32, 1f32], nullable());
        let options = GenerateStatsOptions {
            count_distinct_values: true,
        };
        let stats = FloatStats::generate_opts(&array, options, &mut ctx);

        let encoded = dictionary_encode(array.as_view(), &stats)?;
        // Two dictionary entries for five slots: only 1.0 and 2.0 end up as values.
        assert_eq!(encoded.values().len(), 2);
        assert_eq!(encoded.codes().len(), 5);

        let decoded = encoded
            .as_array()
            .clone()
            .execute::<PrimitiveArray>(&mut ctx)?
            .into_array();
        let expected =
            PrimitiveArray::new(buffer![1f32, 2f32, 2f32, 1f32, 1f32], nullable()).into_array();
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }
}