vortex_sparse/
canonical.rs

1use itertools::Itertools;
2use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, StructArray};
3use vortex_array::patches::Patches;
4use vortex_array::validity::Validity;
5use vortex_array::vtable::CanonicalVTable;
6use vortex_array::{Array, Canonical};
7use vortex_buffer::buffer;
8use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::{SparseArray, SparseVTable};
13
14impl CanonicalVTable<SparseVTable> for SparseVTable {
15    fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
16        if array.patches().num_patches() == 0 {
17            return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18        }
19
20        match array.dtype() {
21            DType::Bool(..) => {
22                let resolved_patches = array.resolved_patches()?;
23                canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
24            }
25            DType::Primitive(ptype, ..) => {
26                let resolved_patches = array.resolved_patches()?;
27                match_each_native_ptype!(ptype, |P| {
28                    canonicalize_sparse_primitives::<P>(&resolved_patches, array.fill_scalar())
29                })
30            }
31            DType::Struct(struct_fields, ..) => {
32                let fill_struct = array.fill_scalar().as_struct();
33                let (fill_values, top_level_fill_validity) = match fill_struct.fields() {
34                    Some(fill_values) => (fill_values, Validity::AllValid),
35                    None => (
36                        struct_fields.fields().map(Scalar::default_value).collect(),
37                        Validity::AllInvalid,
38                    ),
39                };
40                // Resolution is unnecessary b/c we're just pushing the patches into the fields.
41                let patches = array.patches();
42                let patch_values_as_struct = patches.values().to_canonical()?.into_struct()?;
43                let columns_patch_values = patch_values_as_struct.fields();
44                let names = patch_values_as_struct.names();
45                let validity = if array.dtype().is_nullable() {
46                    top_level_fill_validity.patch(
47                        array.len(),
48                        patches.offset(),
49                        patches.indices(),
50                        &Validity::from_mask(
51                            patches.values().validity_mask()?,
52                            Nullability::Nullable,
53                        ),
54                    )?
55                } else {
56                    top_level_fill_validity.into_non_nullable().ok_or_else(|| {
57                        vortex_err!("fill validity should match sparse array nullability")
58                    })?
59                };
60                columns_patch_values
61                    .iter()
62                    .cloned()
63                    .zip_eq(fill_values.into_iter())
64                    .map(|(patch_values, fill_value)| -> VortexResult<_> {
65                        SparseArray::try_new_from_patches(
66                            patches.clone().map_values(|_| Ok(patch_values))?,
67                            fill_value,
68                        )
69                    })
70                    .process_results(|sparse_columns| {
71                        StructArray::try_from_iter_with_validity(
72                            names.iter().zip_eq(sparse_columns),
73                            validity,
74                        )
75                        .map(Canonical::Struct)
76                    })?
77            }
78
79            dtype => vortex_bail!("unsupported type: {}", dtype),
80        }
81    }
82}
83
84fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
85    let (fill_bool, validity) = if fill_value.is_null() {
86        (false, Validity::AllInvalid)
87    } else {
88        (
89            fill_value.try_into()?,
90            if patches.dtype().nullability() == Nullability::NonNullable {
91                Validity::NonNullable
92            } else {
93                Validity::AllValid
94            },
95        )
96    };
97
98    let bools = BoolArray::new(
99        if fill_bool {
100            BooleanBuffer::new_set(patches.array_len())
101        } else {
102            BooleanBuffer::new_unset(patches.array_len())
103        },
104        validity,
105    );
106
107    bools.patch(patches).map(Canonical::Bool)
108}
109
110fn canonicalize_sparse_primitives<
111    T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
112>(
113    patches: &Patches,
114    fill_value: &Scalar,
115) -> VortexResult<Canonical> {
116    let (primitive_fill, validity) = if fill_value.is_null() {
117        (T::default(), Validity::AllInvalid)
118    } else {
119        (
120            fill_value.try_into()?,
121            if patches.dtype().nullability() == Nullability::NonNullable {
122                Validity::NonNullable
123            } else {
124                Validity::AllValid
125            },
126        )
127    };
128
129    let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
130
131    parray.patch(patches).map(Canonical::Primitive)
132}
133
134#[cfg(test)]
135mod test {
136
137    use rstest::rstest;
138    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray, StructArray};
139    use vortex_array::arrow::IntoArrowArray as _;
140    use vortex_array::validity::Validity;
141    use vortex_array::vtable::ValidityHelper;
142    use vortex_array::{IntoArray, ToCanonical};
143    use vortex_buffer::buffer;
144    use vortex_dtype::Nullability::Nullable;
145    use vortex_dtype::{DType, FieldNames, PType, StructFields};
146    use vortex_mask::Mask;
147    use vortex_scalar::Scalar;
148
149    use crate::SparseArray;
150
151    #[rstest]
152    #[case(Some(true))]
153    #[case(Some(false))]
154    #[case(None)]
155    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
156        let indices = buffer![0u64, 1, 7].into_array();
157        let values = bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
158            .into_array();
159        let sparse_bools =
160            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
161        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullable));
162
163        let flat_bools = sparse_bools.to_bool().unwrap();
164        let expected = bool_array_from_nullable_vec(
165            vec![
166                Some(true),
167                None,
168                fill_value,
169                fill_value,
170                fill_value,
171                fill_value,
172                fill_value,
173                Some(false),
174                fill_value,
175                fill_value,
176            ],
177            fill_value,
178        );
179
180        assert_eq!(flat_bools.boolean_buffer(), expected.boolean_buffer());
181        assert_eq!(flat_bools.validity(), expected.validity());
182
183        assert!(flat_bools.boolean_buffer().value(0));
184        assert!(flat_bools.validity().is_valid(0).unwrap());
185        assert_eq!(
186            flat_bools.boolean_buffer().value(1),
187            fill_value.unwrap_or_default()
188        );
189        assert!(!flat_bools.validity().is_valid(1).unwrap());
190        assert_eq!(
191            flat_bools.validity().is_valid(2).unwrap(),
192            fill_value.is_some()
193        );
194        assert!(!flat_bools.boolean_buffer().value(7));
195        assert!(flat_bools.validity().is_valid(7).unwrap());
196    }
197
198    fn bool_array_from_nullable_vec(
199        bools: Vec<Option<bool>>,
200        fill_value: Option<bool>,
201    ) -> BoolArray {
202        let mut buffer = BooleanBufferBuilder::new(bools.len());
203        let mut validity = BooleanBufferBuilder::new(bools.len());
204        for maybe_bool in bools {
205            buffer.append(maybe_bool.unwrap_or_else(|| fill_value.unwrap_or_default()));
206            validity.append(maybe_bool.is_some());
207        }
208        BoolArray::new(buffer.finish(), Validity::from(validity.finish()))
209    }
210
211    #[rstest]
212    #[case(Some(0i32))]
213    #[case(Some(-1i32))]
214    #[case(None)]
215    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
216        let indices = buffer![0u64, 1, 7].into_array();
217        let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
218        let sparse_ints =
219            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
220        assert_eq!(*sparse_ints.dtype(), DType::Primitive(PType::I32, Nullable));
221
222        let flat_ints = sparse_ints.to_primitive().unwrap();
223        let expected = PrimitiveArray::from_option_iter([
224            Some(0i32),
225            None,
226            fill_value,
227            fill_value,
228            fill_value,
229            fill_value,
230            fill_value,
231            Some(1),
232            fill_value,
233            fill_value,
234        ]);
235
236        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
237        assert_eq!(flat_ints.validity(), expected.validity());
238
239        assert_eq!(flat_ints.as_slice::<i32>()[0], 0);
240        assert!(flat_ints.validity().is_valid(0).unwrap());
241        assert_eq!(flat_ints.as_slice::<i32>()[1], 0);
242        assert!(!flat_ints.validity().is_valid(1).unwrap());
243        assert_eq!(
244            flat_ints.as_slice::<i32>()[2],
245            fill_value.unwrap_or_default()
246        );
247        assert_eq!(
248            flat_ints.validity().is_valid(2).unwrap(),
249            fill_value.is_some()
250        );
251        assert_eq!(flat_ints.as_slice::<i32>()[7], 1);
252        assert!(flat_ints.validity().is_valid(7).unwrap());
253    }
254
255    #[test]
256    fn test_sparse_struct_valid_fill() {
257        let field_names = FieldNames::from_iter(["a", "b"]);
258        let field_types = vec![
259            DType::Primitive(PType::I32, Nullable),
260            DType::Primitive(PType::I32, Nullable),
261        ];
262        let struct_fields = StructFields::new(field_names, field_types);
263        let struct_dtype = DType::Struct(struct_fields.clone(), Nullable);
264
265        let indices = buffer![0u64, 1, 7, 8].into_array();
266        let patch_values_a =
267            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(30)]).into_array();
268        let patch_values_b =
269            PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(3)]).into_array();
270        let patch_values = StructArray::try_new_with_dtype(
271            vec![patch_values_a, patch_values_b],
272            struct_fields.clone(),
273            4,
274            Validity::Array(
275                BoolArray::from_indices(4, vec![0, 1, 2], Validity::NonNullable).to_array(),
276            ),
277        )
278        .unwrap()
279        .into_array();
280
281        let fill_scalar = Scalar::struct_(
282            struct_dtype,
283            vec![Scalar::from(Some(-10i32)), Scalar::from(Some(-1i32))],
284        );
285        let len = 10;
286        let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap();
287
288        let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| {
289            if i == 0 {
290                Some(10)
291            } else if i == 1 {
292                None
293            } else if i == 7 {
294                Some(20)
295            } else {
296                Some(-10)
297            }
298        }));
299        let expected_b = PrimitiveArray::from_option_iter((0..len).map(|i| {
300            if i == 0 {
301                Some(1i32)
302            } else if i == 1 {
303                Some(2)
304            } else if i == 7 {
305                None
306            } else {
307                Some(-1)
308            }
309        }));
310
311        let expected = StructArray::try_new_with_dtype(
312            vec![expected_a.into_array(), expected_b.into_array()],
313            struct_fields,
314            len,
315            // NB: patch indices: [0, 1, 7, 8]; patch validity: [Valid, Valid, Valid, Invalid]; ergo 8 is Invalid.
316            Validity::from_mask(Mask::from_excluded_indices(10, vec![8]), Nullable),
317        )
318        .unwrap()
319        .to_array()
320        .into_arrow_preferred()
321        .unwrap();
322
323        let actual = sparse_struct
324            .to_struct()
325            .unwrap()
326            .to_array()
327            .into_arrow_preferred()
328            .unwrap();
329
330        assert_eq!(expected.data_type(), actual.data_type());
331        assert_eq!(&expected, &actual);
332    }
333
334    #[test]
335    fn test_sparse_struct_invalid_fill() {
336        let field_names = FieldNames::from_iter(["a", "b"]);
337        let field_types = vec![
338            DType::Primitive(PType::I32, Nullable),
339            DType::Primitive(PType::I32, Nullable),
340        ];
341        let struct_fields = StructFields::new(field_names, field_types);
342        let struct_dtype = DType::Struct(struct_fields.clone(), Nullable);
343
344        let indices = buffer![0u64, 1, 7, 8].into_array();
345        let patch_values_a =
346            PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(30)]).into_array();
347        let patch_values_b =
348            PrimitiveArray::from_option_iter([Some(1i32), Some(2), None, Some(3)]).into_array();
349        let patch_values = StructArray::try_new_with_dtype(
350            vec![patch_values_a, patch_values_b],
351            struct_fields.clone(),
352            4,
353            Validity::Array(
354                BoolArray::from_indices(4, vec![0, 1, 2], Validity::NonNullable).to_array(),
355            ),
356        )
357        .unwrap()
358        .into_array();
359
360        let fill_scalar = Scalar::null(struct_dtype);
361        let len = 10;
362        let sparse_struct = SparseArray::try_new(indices, patch_values, len, fill_scalar).unwrap();
363
364        let expected_a = PrimitiveArray::from_option_iter((0..len).map(|i| {
365            if i == 0 {
366                Some(10)
367            } else if i == 1 {
368                None
369            } else if i == 7 {
370                Some(20)
371            } else {
372                Some(-10)
373            }
374        }));
375        let expected_b = PrimitiveArray::from_option_iter((0..len).map(|i| {
376            if i == 0 {
377                Some(1i32)
378            } else if i == 1 {
379                Some(2)
380            } else if i == 7 {
381                None
382            } else {
383                Some(-1)
384            }
385        }));
386
387        let expected = StructArray::try_new_with_dtype(
388            vec![expected_a.into_array(), expected_b.into_array()],
389            struct_fields,
390            len,
391            // NB: patch indices: [0, 1, 7, 8]; patch validity: [Valid, Valid, Valid, Invalid]; ergo 0, 1, 7 are valid.
392            Validity::from_mask(Mask::from_indices(10, vec![0, 1, 7]), Nullable),
393        )
394        .unwrap()
395        .to_array()
396        .into_arrow_preferred()
397        .unwrap();
398
399        let actual = sparse_struct
400            .to_struct()
401            .unwrap()
402            .to_array()
403            .into_arrow_preferred()
404            .unwrap();
405
406        assert_eq!(expected.data_type(), actual.data_type());
407        assert_eq!(&expected, &actual);
408    }
409}