vortex_sparse/
canonical.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::vtable::CanonicalVTable;
5use vortex_array::{Array, Canonical};
6use vortex_buffer::buffer;
7use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
8use vortex_error::{VortexError, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::{SparseArray, SparseVTable};
12
13impl CanonicalVTable<SparseVTable> for SparseVTable {
14    fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
15        let resolved_patches = array.resolved_patches()?;
16        if resolved_patches.num_patches() == 0 {
17            return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18        }
19
20        if matches!(array.dtype(), DType::Bool(_)) {
21            canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
22        } else {
23            let ptype = PType::try_from(resolved_patches.values().dtype())?;
24            match_each_native_ptype!(ptype, |P| {
25                canonicalize_sparse_primitives::<P>(&resolved_patches, array.fill_scalar())
26            })
27        }
28    }
29}
30
31fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
32    let (fill_bool, validity) = if fill_value.is_null() {
33        (false, Validity::AllInvalid)
34    } else {
35        (
36            fill_value.try_into()?,
37            if patches.dtype().nullability() == Nullability::NonNullable {
38                Validity::NonNullable
39            } else {
40                Validity::AllValid
41            },
42        )
43    };
44
45    let bools = BoolArray::new(
46        if fill_bool {
47            BooleanBuffer::new_set(patches.array_len())
48        } else {
49            BooleanBuffer::new_unset(patches.array_len())
50        },
51        validity,
52    );
53
54    bools.patch(patches).map(Canonical::Bool)
55}
56
57fn canonicalize_sparse_primitives<
58    T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
59>(
60    patches: &Patches,
61    fill_value: &Scalar,
62) -> VortexResult<Canonical> {
63    let (primitive_fill, validity) = if fill_value.is_null() {
64        (T::default(), Validity::AllInvalid)
65    } else {
66        (
67            fill_value.try_into()?,
68            if patches.dtype().nullability() == Nullability::NonNullable {
69                Validity::NonNullable
70            } else {
71                Validity::AllValid
72            },
73        )
74    };
75
76    let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
77
78    parray.patch(patches).map(Canonical::Primitive)
79}
80
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Canonicalizes a length-10 sparse boolean array with patches at indices
    /// 0, 1 and 7 (`true`, null, `false`) for a `true`, `false` and null fill,
    /// and compares both the bit buffer and the validity against a densely
    /// built expectation.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let indices = buffer![0u64, 1, 7].into_array();
        let values = bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
            .into_array();
        let sparse_bools =
            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullability::Nullable));

        let flat_bools = sparse_bools.to_bool().unwrap();
        let expected = bool_array_from_nullable_vec(
            vec![
                Some(true),
                None,
                fill_value,
                fill_value,
                fill_value,
                fill_value,
                fill_value,
                Some(false),
                fill_value,
                fill_value,
            ],
            fill_value,
        );

        // Whole-array comparison first, then spot checks on individual slots.
        assert_eq!(flat_bools.boolean_buffer(), expected.boolean_buffer());
        assert_eq!(flat_bools.validity(), expected.validity());

        // Index 0: patched with a valid `true`.
        assert!(flat_bools.boolean_buffer().value(0));
        assert!(flat_bools.validity().is_valid(0).unwrap());
        // Index 1: patched with a null; the bit matches the helper's placeholder.
        assert_eq!(
            flat_bools.boolean_buffer().value(1),
            fill_value.unwrap_or_default()
        );
        assert!(!flat_bools.validity().is_valid(1).unwrap());
        // Index 2: unpatched, so it is valid only when the fill is non-null.
        assert_eq!(
            flat_bools.validity().is_valid(2).unwrap(),
            fill_value.is_some()
        );
        // Index 7: patched with a valid `false`.
        assert!(!flat_bools.boolean_buffer().value(7));
        assert!(flat_bools.validity().is_valid(7).unwrap());
    }

    /// Builds a `BoolArray` from optional booleans.
    ///
    /// `None` entries are marked invalid; their underlying bit is set to
    /// `fill_value`, or to `false` when `fill_value` is itself `None`.
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        // Bit stored behind null entries; identical for every element.
        let placeholder = fill_value.unwrap_or_default();

        let mut buffer = BooleanBufferBuilder::new(bools.len());
        let mut validity = BooleanBufferBuilder::new(bools.len());
        for maybe_bool in bools {
            buffer.append(maybe_bool.unwrap_or(placeholder));
            validity.append(maybe_bool.is_some());
        }
        BoolArray::new(buffer.finish(), Validity::from(validity.finish()))
    }

    /// Canonicalizes a length-10 sparse `i32` array with patches at indices
    /// 0, 1 and 7 (`0`, null, `1`) for a `0`, `-1` and null fill, and compares
    /// both the byte buffer and the validity against a densely built
    /// expectation.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        let indices = buffer![0u64, 1, 7].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse_ints =
            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
        assert_eq!(
            *sparse_ints.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let flat_ints = sparse_ints.to_primitive().unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(0i32),
            None,
            fill_value,
            fill_value,
            fill_value,
            fill_value,
            fill_value,
            Some(1),
            fill_value,
            fill_value,
        ]);

        // Whole-array comparison first, then spot checks on individual slots.
        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
        assert_eq!(flat_ints.validity(), expected.validity());

        // Index 0: patched with a valid `0`.
        assert_eq!(flat_ints.as_slice::<i32>()[0], 0);
        assert!(flat_ints.validity().is_valid(0).unwrap());
        // Index 1: patched with a null; the stored value is expected to be `0`.
        assert_eq!(flat_ints.as_slice::<i32>()[1], 0);
        assert!(!flat_ints.validity().is_valid(1).unwrap());
        // Index 2: unpatched, so it holds the fill (or the default for a null fill).
        assert_eq!(
            flat_ints.as_slice::<i32>()[2],
            fill_value.unwrap_or_default()
        );
        assert_eq!(
            flat_ints.validity().is_valid(2).unwrap(),
            fill_value.is_some()
        );
        // Index 7: patched with a valid `1`.
        assert_eq!(flat_ints.as_slice::<i32>()[7], 1);
        assert!(flat_ints.validity().is_valid(7).unwrap());
    }
}