vortex_layout/layouts/
compact.rs1use vortex_array::arrays::{
5 ExtensionArray, FixedSizeListArray, ListViewArray, PrimitiveArray, StructArray,
6 narrowed_decimal,
7};
8use vortex_array::vtable::ValidityHelper;
9use vortex_array::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
10use vortex_decimal_byte_parts::DecimalBytePartsArray;
11use vortex_dtype::PType;
12use vortex_error::VortexResult;
13use vortex_pco::PcoArray;
14use vortex_scalar::DecimalType;
15use vortex_zstd::ZstdArray;
16
17fn is_pco_number_type(ptype: PType) -> bool {
18 matches!(
19 ptype,
20 PType::F16
21 | PType::F32
22 | PType::F64
23 | PType::I16
24 | PType::I32
25 | PType::I64
26 | PType::U16
27 | PType::U32
28 | PType::U64
29 )
30}
31
/// Configuration for the "compact" compression strategy implemented in this file.
///
/// Primitive arrays whose type passes `is_pco_number_type` are encoded with Pco; other
/// primitive arrays and var-bin-view arrays are encoded with Zstd. Nested arrays have
/// their children compressed recursively.
#[derive(Clone, Debug)]
pub struct CompactCompressor {
    /// Compression level handed to `PcoArray::from_primitive`.
    pco_level: usize,
    /// Compression level handed to the `ZstdArray` constructors.
    zstd_level: i32,
    /// Page-size argument handed to both the Pco and the Zstd constructors.
    values_per_page: usize,
}
41
42impl CompactCompressor {
43 pub fn with_pco_level(mut self, level: usize) -> Self {
44 self.pco_level = level;
45 self
46 }
47
48 pub fn with_zstd_level(mut self, level: i32) -> Self {
49 self.zstd_level = level;
50 self
51 }
52
53 pub fn with_values_per_page(mut self, values_per_page: usize) -> Self {
60 self.values_per_page = values_per_page;
61 self
62 }
63
64 pub fn compress(&self, array: &dyn Array) -> VortexResult<ArrayRef> {
65 self.compress_canonical(array.to_canonical())
66 }
67
68 pub fn compress_canonical(&self, canonical: Canonical) -> VortexResult<ArrayRef> {
70 let uncompressed_nbytes = canonical.as_ref().nbytes();
71 let compressed = match &canonical {
72 Canonical::Primitive(primitive) => {
74 let ptype = primitive.ptype();
76
77 if is_pco_number_type(ptype) {
78 let pco_array =
79 PcoArray::from_primitive(primitive, self.pco_level, self.values_per_page)?;
80 pco_array.into_array()
81 } else {
82 let zstd_array = ZstdArray::from_primitive(
83 primitive,
84 self.zstd_level,
85 self.values_per_page,
86 )?;
87 zstd_array.into_array()
88 }
89 }
90 Canonical::Decimal(decimal) => {
91 let decimal = narrowed_decimal(decimal.clone());
92 let validity = decimal.validity();
93 let int_values = match decimal.values_type() {
94 DecimalType::I8 => {
95 PrimitiveArray::new(decimal.buffer::<i8>(), validity.clone())
96 }
97 DecimalType::I16 => {
98 PrimitiveArray::new(decimal.buffer::<i16>(), validity.clone())
99 }
100 DecimalType::I32 => {
101 PrimitiveArray::new(decimal.buffer::<i32>(), validity.clone())
102 }
103 DecimalType::I64 => {
104 PrimitiveArray::new(decimal.buffer::<i64>(), validity.clone())
105 }
106 _ => {
107 return Ok(canonical.into_array());
109 }
110 };
111 let compressed = self.compress_canonical(Canonical::Primitive(int_values))?;
112 DecimalBytePartsArray::try_new(compressed, decimal.decimal_dtype())?.to_array()
113 }
114 Canonical::VarBinView(vbv) => {
115 ZstdArray::from_var_bin_view(vbv, self.zstd_level, self.values_per_page)?
117 .into_array()
118 }
119 Canonical::Struct(struct_array) => {
120 let fields = struct_array
122 .fields()
123 .iter()
124 .map(|field| self.compress(field))
125 .collect::<VortexResult<Vec<_>>>()?;
126
127 StructArray::try_new(
128 struct_array.names().clone(),
129 fields,
130 struct_array.len(),
131 struct_array.validity().clone(),
132 )?
133 .into_array()
134 }
135 Canonical::List(listview) => {
136 let compressed_elems = self.compress(listview.elements())?;
137
138 let compressed_offsets =
141 self.compress(&listview.offsets().to_primitive().narrow()?.into_array())?;
142 let compressed_sizes =
143 self.compress(&listview.sizes().to_primitive().narrow()?.into_array())?;
144
145 unsafe {
150 ListViewArray::new_unchecked(
151 compressed_elems,
152 compressed_offsets,
153 compressed_sizes,
154 listview.validity().clone(),
155 )
156 .with_zero_copy_to_list(listview.is_zero_copy_to_list())
157 }
158 .into_array()
159 }
160 Canonical::FixedSizeList(fsl) => {
161 let compressed_elems = self.compress(fsl.elements())?;
162
163 FixedSizeListArray::try_new(
164 compressed_elems,
165 fsl.list_size(),
166 fsl.validity().clone(),
167 fsl.len(),
168 )?
169 .into_array()
170 }
171 Canonical::Extension(ext_array) => {
172 let compressed_storage = self.compress(ext_array.storage())?;
173
174 ExtensionArray::new(ext_array.ext_dtype().clone(), compressed_storage).into_array()
175 }
176 _ => return Ok(canonical.into_array()),
177 };
178
179 if compressed.nbytes() >= uncompressed_nbytes {
180 return Ok(canonical.into_array());
181 }
182 Ok(compressed)
183 }
184}
185
186impl Default for CompactCompressor {
187 fn default() -> Self {
188 Self {
189 pco_level: pco::DEFAULT_COMPRESSION_LEVEL,
190 zstd_level: 3,
191 values_per_page: 8192,
196 }
197 }
198}
199
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical, assert_arrays_eq};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldName;

    use super::*;

    /// Round-trips a non-nullable struct holding an `f64`, an `i32` and a `u8` column
    /// through the default compressor and checks that names, lengths and values survive.
    #[test]
    fn test_compact_compressor_struct_with_mixed_types() {
        let compressor = CompactCompressor::default();

        // One column per primitive type exercised by this test.
        let columns = vec![
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable)
                .into_array(),
            PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50], Validity::NonNullable)
                .into_array(),
            PrimitiveArray::new(buffer![11u8, 22, 33, 44, 55], Validity::NonNullable)
                .into_array(),
        ];
        let field_names: Vec<FieldName> =
            vec!["f64_field".into(), "i32_field".into(), "u8_field".into()];

        let n_rows = columns[0].len();
        let input = StructArray::try_new(
            field_names.clone().into(),
            columns.clone(),
            n_rows,
            Validity::NonNullable,
        )
        .unwrap();

        let compressed = compressor.compress(input.as_ref()).unwrap();

        // Decode back to canonical form and compare against the original columns.
        let round_tripped = compressed.to_canonical().into_array();
        assert_eq!(round_tripped.len(), n_rows);
        let round_tripped_struct = round_tripped.to_struct();

        for (i, name) in round_tripped_struct.names().iter().enumerate() {
            assert_eq!(name, field_names[i]);
            let actual_column = round_tripped_struct.field_by_name(name).unwrap().clone();
            assert_eq!(actual_column.len(), n_rows);

            assert_arrays_eq!(actual_column.as_ref(), columns[i].as_ref());
        }
    }
}