vortex_layout/layouts/
compact.rs1use vortex_array::arrays::{
5 ExtensionArray, FixedSizeListArray, ListViewArray, PrimitiveArray, StructArray,
6 narrowed_decimal,
7};
8use vortex_array::vtable::ValidityHelper;
9use vortex_array::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
10use vortex_decimal_byte_parts::DecimalBytePartsArray;
11use vortex_dtype::PType;
12use vortex_error::VortexResult;
13use vortex_pco::PcoArray;
14use vortex_scalar::DecimalValueType;
15use vortex_zstd::ZstdArray;
16
17fn is_pco_number_type(ptype: PType) -> bool {
18 matches!(
19 ptype,
20 PType::F16
21 | PType::F32
22 | PType::F64
23 | PType::I16
24 | PType::I32
25 | PType::I64
26 | PType::U16
27 | PType::U32
28 | PType::U64
29 )
30}
31
/// Settings used by the `compress` / `compress_canonical` methods implemented on this type.
///
/// Every field can be overridden through the matching `with_*` builder method.
#[derive(Clone, Debug)]
pub struct CompactCompressor {
    /// Level forwarded to `PcoArray::from_primitive`.
    pco_level: usize,
    /// Level forwarded to the `ZstdArray` constructors.
    zstd_level: i32,
    /// Values-per-page argument forwarded to both the `PcoArray` and `ZstdArray` constructors.
    values_per_page: usize,
}
41
42impl CompactCompressor {
43 pub fn with_pco_level(mut self, level: usize) -> Self {
44 self.pco_level = level;
45 self
46 }
47
48 pub fn with_zstd_level(mut self, level: i32) -> Self {
49 self.zstd_level = level;
50 self
51 }
52
53 pub fn with_values_per_page(mut self, values_per_page: usize) -> Self {
60 self.values_per_page = values_per_page;
61 self
62 }
63
64 pub fn compress(&self, array: &dyn Array) -> VortexResult<ArrayRef> {
65 self.compress_canonical(array.to_canonical())
66 }
67
68 pub fn compress_canonical(&self, canonical: Canonical) -> VortexResult<ArrayRef> {
70 let uncompressed_nbytes = canonical.as_ref().nbytes();
71 let compressed = match &canonical {
72 Canonical::Primitive(primitive) => {
74 let ptype = primitive.ptype();
76
77 if is_pco_number_type(ptype) {
78 let pco_array =
79 PcoArray::from_primitive(primitive, self.pco_level, self.values_per_page)?;
80 pco_array.into_array()
81 } else {
82 let zstd_array = ZstdArray::from_primitive(
83 primitive,
84 self.zstd_level,
85 self.values_per_page,
86 )?;
87 zstd_array.into_array()
88 }
89 }
90 Canonical::Decimal(decimal) => {
91 let decimal = narrowed_decimal(decimal.clone());
92 let validity = decimal.validity();
93 let int_values = match decimal.values_type() {
94 DecimalValueType::I8 => {
95 PrimitiveArray::new(decimal.buffer::<i8>(), validity.clone())
96 }
97 DecimalValueType::I16 => {
98 PrimitiveArray::new(decimal.buffer::<i16>(), validity.clone())
99 }
100 DecimalValueType::I32 => {
101 PrimitiveArray::new(decimal.buffer::<i32>(), validity.clone())
102 }
103 DecimalValueType::I64 => {
104 PrimitiveArray::new(decimal.buffer::<i64>(), validity.clone())
105 }
106 _ => {
107 return Ok(canonical.into_array());
109 }
110 };
111 let compressed = self.compress_canonical(Canonical::Primitive(int_values))?;
112 DecimalBytePartsArray::try_new(compressed, decimal.decimal_dtype())?.to_array()
113 }
114 Canonical::VarBinView(vbv) => {
115 ZstdArray::from_var_bin_view(vbv, self.zstd_level, self.values_per_page)?
117 .into_array()
118 }
119 Canonical::Struct(struct_array) => {
120 let fields = struct_array
122 .fields()
123 .iter()
124 .map(|field| self.compress(field))
125 .collect::<VortexResult<Vec<_>>>()?;
126
127 StructArray::try_new(
128 struct_array.names().clone(),
129 fields,
130 struct_array.len(),
131 struct_array.validity().clone(),
132 )?
133 .into_array()
134 }
135 Canonical::List(list_array) => {
136 let compressed_elems = self.compress(list_array.elements())?;
137
138 let compressed_offsets =
141 self.compress(&list_array.offsets().to_primitive().narrow()?.into_array())?;
142 let compressed_sizes =
143 self.compress(&list_array.sizes().to_primitive().narrow()?.into_array())?;
144
145 ListViewArray::try_new(
146 compressed_elems,
147 compressed_offsets,
148 compressed_sizes,
149 list_array.validity().clone(),
150 )?
151 .into_array()
152 }
153 Canonical::FixedSizeList(list_array) => {
154 let compressed_elems = self.compress(list_array.elements())?;
155
156 FixedSizeListArray::try_new(
157 compressed_elems,
158 list_array.list_size(),
159 list_array.validity().clone(),
160 list_array.len(),
161 )?
162 .into_array()
163 }
164 Canonical::Extension(ext_array) => {
165 let compressed_storage = self.compress(ext_array.storage())?;
166
167 ExtensionArray::new(ext_array.ext_dtype().clone(), compressed_storage).into_array()
168 }
169 _ => return Ok(canonical.into_array()),
170 };
171
172 if compressed.nbytes() >= uncompressed_nbytes {
173 return Ok(canonical.into_array());
174 }
175 Ok(compressed)
176 }
177}
178
179impl Default for CompactCompressor {
180 fn default() -> Self {
181 Self {
182 pco_level: pco::DEFAULT_COMPRESSION_LEVEL,
183 zstd_level: 3,
184 values_per_page: 8192,
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldName;

    use super::*;

    /// Compresses a non-nullable struct with `f64`, `i32` and `u8` columns using the default
    /// compressor, canonicalizes the result, and checks that field names, lengths and every
    /// scalar match the input.
    #[test]
    fn test_compact_compressor_struct_with_mixed_types() {
        let compressor = CompactCompressor::default();

        // One column per primitive type under test, already converted to arrays.
        let columns = vec![
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable)
                .into_array(),
            PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50], Validity::NonNullable)
                .into_array(),
            PrimitiveArray::new(buffer![11u8, 22, 33, 44, 55], Validity::NonNullable).into_array(),
        ];
        let field_names: Vec<FieldName> =
            vec!["f64_field".into(), "i32_field".into(), "u8_field".into()];

        let n_rows = columns[0].len();
        let input = StructArray::try_new(
            field_names.clone().into(),
            columns.clone(),
            n_rows,
            Validity::NonNullable,
        )
        .unwrap();

        let compressed = compressor.compress(input.as_ref()).unwrap();

        // Canonicalize the compressed array to get back to a plain struct.
        let round_tripped = compressed.to_canonical().into_array();
        assert_eq!(round_tripped.len(), n_rows);
        let actual_struct = round_tripped.to_struct();

        for (idx, name) in actual_struct.names().iter().enumerate() {
            assert_eq!(name, field_names[idx]);
            let actual_column = actual_struct.field_by_name(name).unwrap().to_primitive();
            assert_eq!(actual_column.len(), n_rows);

            // Compare row by row against the column that went in.
            for row in 0..n_rows {
                assert_eq!(actual_column.scalar_at(row), columns[idx].scalar_at(row));
            }
        }
    }
}