1use itertools::Itertools;
2use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, StructArray};
3use vortex_array::patches::Patches;
4use vortex_array::validity::Validity;
5use vortex_array::vtable::CanonicalVTable;
6use vortex_array::{Array, Canonical};
7use vortex_buffer::buffer;
8use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseVTable};
13
14impl CanonicalVTable<SparseVTable> for SparseVTable {
15 fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
16 if array.patches().num_patches() == 0 {
17 return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18 }
19
20 match array.dtype() {
21 DType::Bool(..) => {
22 let resolved_patches = array.resolved_patches()?;
23 canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
24 }
25 DType::Primitive(ptype, ..) => {
26 let resolved_patches = array.resolved_patches()?;
27 match_each_native_ptype!(ptype, |P| {
28 canonicalize_sparse_primitives::<P>(&resolved_patches, array.fill_scalar())
29 })
30 }
31 DType::Struct(struct_fields, ..) => {
32 let fill_struct = array.fill_scalar().as_struct();
33 let (fill_values, top_level_fill_validity) = match fill_struct.fields() {
34 Some(fill_values) => (fill_values, Validity::AllValid),
35 None => (
36 struct_fields.fields().map(Scalar::default_value).collect(),
37 Validity::AllInvalid,
38 ),
39 };
40 let patches = array.patches();
42 let patch_values_as_struct = patches.values().to_canonical()?.into_struct()?;
43 let columns_patch_values = patch_values_as_struct.fields();
44 let names = patch_values_as_struct.names();
45 let validity = if array.dtype().is_nullable() {
46 top_level_fill_validity.patch(
47 array.len(),
48 patches.offset(),
49 patches.indices(),
50 &Validity::from_mask(
51 patches.values().validity_mask()?,
52 Nullability::Nullable,
53 ),
54 )?
55 } else {
56 top_level_fill_validity.into_non_nullable().ok_or_else(|| {
57 vortex_err!("fill validity should match sparse array nullability")
58 })?
59 };
60 columns_patch_values
61 .iter()
62 .cloned()
63 .zip_eq(fill_values.into_iter())
64 .map(|(patch_values, fill_value)| -> VortexResult<_> {
65 SparseArray::try_new_from_patches(
66 patches.clone().map_values(|_| Ok(patch_values))?,
67 fill_value,
68 )
69 })
70 .process_results(|sparse_columns| {
71 StructArray::try_from_iter_with_validity(
72 names.iter().zip_eq(sparse_columns),
73 validity,
74 )
75 .map(Canonical::Struct)
76 })?
77 }
78
79 dtype => vortex_bail!("unsupported type: {}", dtype),
80 }
81 }
82}
83
84fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
85 let (fill_bool, validity) = if fill_value.is_null() {
86 (false, Validity::AllInvalid)
87 } else {
88 (
89 fill_value.try_into()?,
90 if patches.dtype().nullability() == Nullability::NonNullable {
91 Validity::NonNullable
92 } else {
93 Validity::AllValid
94 },
95 )
96 };
97
98 let bools = BoolArray::new(
99 if fill_bool {
100 BooleanBuffer::new_set(patches.array_len())
101 } else {
102 BooleanBuffer::new_unset(patches.array_len())
103 },
104 validity,
105 );
106
107 bools.patch(patches).map(Canonical::Bool)
108}
109
110fn canonicalize_sparse_primitives<
111 T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
112>(
113 patches: &Patches,
114 fill_value: &Scalar,
115) -> VortexResult<Canonical> {
116 let (primitive_fill, validity) = if fill_value.is_null() {
117 (T::default(), Validity::AllInvalid)
118 } else {
119 (
120 fill_value.try_into()?,
121 if patches.dtype().nullability() == Nullability::NonNullable {
122 Validity::NonNullable
123 } else {
124 Validity::AllValid
125 },
126 )
127 };
128
129 let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
130
131 parray.patch(patches).map(Canonical::Primitive)
132}
133
134#[cfg(test)]
135mod test {
136
137 use rstest::rstest;
138 use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray, StructArray};
139 use vortex_array::arrow::IntoArrowArray as _;
140 use vortex_array::validity::Validity;
141 use vortex_array::vtable::ValidityHelper;
142 use vortex_array::{IntoArray, ToCanonical};
143 use vortex_buffer::buffer;
144 use vortex_dtype::Nullability::Nullable;
145 use vortex_dtype::{DType, FieldNames, PType, StructFields};
146 use vortex_mask::Mask;
147 use vortex_scalar::Scalar;
148
149 use crate::SparseArray;
150
151 #[rstest]
152 #[case(Some(true))]
153 #[case(Some(false))]
154 #[case(None)]
155 fn test_sparse_bool(#[case] fill_value: Option<bool>) {
156 let indices = buffer![0u64, 1, 7].into_array();
157 let values = bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
158 .into_array();
159 let sparse_bools =
160 SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
161 assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullable));
162
163 let flat_bools = sparse_bools.to_bool().unwrap();
164 let expected = bool_array_from_nullable_vec(
165 vec![
166 Some(true),
167 None,
168 fill_value,
169 fill_value,
170 fill_value,
171 fill_value,
172 fill_value,
173 Some(false),
174 fill_value,
175 fill_value,
176 ],
177 fill_value,
178 );
179
180 assert_eq!(flat_bools.boolean_buffer(), expected.boolean_buffer());
181 assert_eq!(flat_bools.validity(), expected.validity());
182
183 assert!(flat_bools.boolean_buffer().value(0));
184 assert!(flat_bools.validity().is_valid(0).unwrap());
185 assert_eq!(
186 flat_bools.boolean_buffer().value(1),
187 fill_value.unwrap_or_default()
188 );
189 assert!(!flat_bools.validity().is_valid(1).unwrap());
190 assert_eq!(
191 flat_bools.validity().is_valid(2).unwrap(),
192 fill_value.is_some()
193 );
194 assert!(!flat_bools.boolean_buffer().value(7));
195 assert!(flat_bools.validity().is_valid(7).unwrap());
196 }
197
198 fn bool_array_from_nullable_vec(
199 bools: Vec<Option<bool>>,
200 fill_value: Option<bool>,
201 ) -> BoolArray {
202 let mut buffer = BooleanBufferBuilder::new(bools.len());
203 let mut validity = BooleanBufferBuilder::new(bools.len());
204 for maybe_bool in bools {
205 buffer.append(maybe_bool.unwrap_or_else(|| fill_value.unwrap_or_default()));
206 validity.append(maybe_bool.is_some());
207 }
208 BoolArray::new(buffer.finish(), Validity::from(validity.finish()))
209 }
210
211 #[rstest]
212 #[case(Some(0i32))]
213 #[case(Some(-1i32))]
214 #[case(None)]
215 fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
216 let indices = buffer![0u64, 1, 7].into_array();
217 let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
218 let sparse_ints =
219 SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
220 assert_eq!(*sparse_ints.dtype(), DType::Primitive(PType::I32, Nullable));
221
222 let flat_ints = sparse_ints.to_primitive().unwrap();
223 let expected = PrimitiveArray::from_option_iter([
224 Some(0i32),
225 None,
226 fill_value,
227 fill_value,
228 fill_value,
229 fill_value,
230 fill_value,
231 Some(1),
232 fill_value,
233 fill_value,
234 ]);
235
236 assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
237 assert_eq!(flat_ints.validity(), expected.validity());
238
239 assert_eq!(flat_ints.as_slice::<i32>()[0], 0);
240 assert!(flat_ints.validity().is_valid(0).unwrap());
241 assert_eq!(flat_ints.as_slice::<i32>()[1], 0);
242 assert!(!flat_ints.validity().is_valid(1).unwrap());
243 assert_eq!(
244 flat_ints.as_slice::<i32>()[2],
245 fill_value.unwrap_or_default()
246 );
247 assert_eq!(
248 flat_ints.validity().is_valid(2).unwrap(),
249 fill_value.is_some()
250 );
251 assert_eq!(flat_ints.as_slice::<i32>()[7], 1);
252 assert!(flat_ints.validity().is_valid(7).unwrap());
253 }
254
255 #[test]
256 fn test_sparse_struct_valid_fill() {
257 let field_names = FieldNames::from_iter(["a", "b"]);
258 let field_types = vec![
259 DType::Primitive(PType::I32, Nullable),
260 DType::Primitive(PType::I32, Nullable),
261 ];
262 let struct_fields = StructFields::new(field_names, field_types);
263 let struct_dtype = DType::Struct(struct_fields.clone(), Nullable);
264
265 let indices = buffer![0u64, 1, 7, 8].into_array();
266 let patch_values_a =
267 PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(30)]).into_array();
268 let patch_values_b =
269 PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(3)]).into_array();
270 let patch_values = StructArray::try_new_with_dtype(
271 vec![patch_values_a, patch_values_b],
272 struct_fields.clone(),
273 4,
274 Validity::Array(
275 BoolArray::from_indices(4, vec![0, 1, 2], Validity::NonNullable).to_array(),
276 ),
277 )
278 .unwrap()
279 .into_array();
280
281 let fill_scalar = Scalar::struct_(
282 struct_dtype,
283 vec![Scalar::from(Some(-10i32)), Scalar::from(Some(-1i32))],
284 );
285 let len = 10;
286 let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap();
287
288 let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| {
289 if i == 0 {
290 Some(10)
291 } else if i == 1 {
292 None
293 } else if i == 7 {
294 Some(20)
295 } else {
296 Some(-10)
297 }
298 }));
299 let expected_b = PrimitiveArray::from_option_iter((0..len).map(|i| {
300 if i == 0 {
301 Some(1i32)
302 } else if i == 1 {
303 Some(2)
304 } else if i == 7 {
305 None
306 } else {
307 Some(-1)
308 }
309 }));
310
311 let expected = StructArray::try_new_with_dtype(
312 vec![expected_a.into_array(), expected_b.into_array()],
313 struct_fields,
314 len,
315 Validity::from_mask(Mask::from_excluded_indices(10, vec![8]), Nullable),
317 )
318 .unwrap()
319 .to_array()
320 .into_arrow_preferred()
321 .unwrap();
322
323 let actual = sparse_struct
324 .to_struct()
325 .unwrap()
326 .to_array()
327 .into_arrow_preferred()
328 .unwrap();
329
330 assert_eq!(expected.data_type(), actual.data_type());
331 assert_eq!(&expected, &actual);
332 }
333
334 #[test]
335 fn test_sparse_struct_invalid_fill() {
336 let field_names = FieldNames::from_iter(["a", "b"]);
337 let field_types = vec![
338 DType::Primitive(PType::I32, Nullable),
339 DType::Primitive(PType::I32, Nullable),
340 ];
341 let struct_fields = StructFields::new(field_names, field_types);
342 let struct_dtype = DType::Struct(struct_fields.clone(), Nullable);
343
344 let indices = buffer![0u64, 1, 7, 8].into_array();
345 let patch_values_a =
346 PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(30)]).into_array();
347 let patch_values_b =
348 PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(3)]).into_array();
349 let patch_values = StructArray::try_new_with_dtype(
350 vec![patch_values_a, patch_values_b],
351 struct_fields.clone(),
352 4,
353 Validity::Array(
354 BoolArray::from_indices(4, vec![0, 1, 2], Validity::NonNullable).to_array(),
355 ),
356 )
357 .unwrap()
358 .into_array();
359
360 let fill_scalar = Scalar::null(struct_dtype);
361 let len = 10;
362 let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap();
363
364 let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| {
365 if i == 0 {
366 Some(10)
367 } else if i == 1 {
368 None
369 } else if i == 7 {
370 Some(20)
371 } else {
372 Some(-10)
373 }
374 }));
375 let expected_b = PrimitiveArray::from_option_iter((0..len).map(|i| {
376 if i == 0 {
377 Some(1i32)
378 } else if i == 1 {
379 Some(2)
380 } else if i == 7 {
381 None
382 } else {
383 Some(-1)
384 }
385 }));
386
387 let expected = StructArray::try_new_with_dtype(
388 vec![expected_a.into_array(), expected_b.into_array()],
389 struct_fields,
390 len,
391 Validity::from_mask(Mask::from_indices(10, vec![0, 1, 7]), Nullable),
393 )
394 .unwrap()
395 .to_array()
396 .into_arrow_preferred()
397 .unwrap();
398
399 let actual = sparse_struct
400 .to_struct()
401 .unwrap()
402 .to_array()
403 .into_arrow_preferred()
404 .unwrap();
405
406 assert_eq!(expected.data_type(), actual.data_type());
407 assert_eq!(&expected, &actual);
408 }
409}