1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::vtable::CanonicalVTable;
5use vortex_array::{Array, Canonical};
6use vortex_buffer::buffer;
7use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
8use vortex_error::{VortexError, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::{SparseArray, SparseVTable};
12
13impl CanonicalVTable<SparseVTable> for SparseVTable {
14 fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
15 let resolved_patches = array.resolved_patches()?;
16 if resolved_patches.num_patches() == 0 {
17 return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18 }
19
20 if matches!(array.dtype(), DType::Bool(_)) {
21 canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
22 } else {
23 let ptype = PType::try_from(resolved_patches.values().dtype())?;
24 match_each_native_ptype!(ptype, |P| {
25 canonicalize_sparse_primitives::<P>(&resolved_patches, array.fill_scalar())
26 })
27 }
28 }
29}
30
31fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
32 let (fill_bool, validity) = if fill_value.is_null() {
33 (false, Validity::AllInvalid)
34 } else {
35 (
36 fill_value.try_into()?,
37 if patches.dtype().nullability() == Nullability::NonNullable {
38 Validity::NonNullable
39 } else {
40 Validity::AllValid
41 },
42 )
43 };
44
45 let bools = BoolArray::new(
46 if fill_bool {
47 BooleanBuffer::new_set(patches.array_len())
48 } else {
49 BooleanBuffer::new_unset(patches.array_len())
50 },
51 validity,
52 );
53
54 bools.patch(patches).map(Canonical::Bool)
55}
56
57fn canonicalize_sparse_primitives<
58 T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
59>(
60 patches: &Patches,
61 fill_value: &Scalar,
62) -> VortexResult<Canonical> {
63 let (primitive_fill, validity) = if fill_value.is_null() {
64 (T::default(), Validity::AllInvalid)
65 } else {
66 (
67 fill_value.try_into()?,
68 if patches.dtype().nullability() == Nullability::NonNullable {
69 Validity::NonNullable
70 } else {
71 Validity::AllValid
72 },
73 )
74 };
75
76 let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
77
78 parray.patch(patches).map(Canonical::Primitive)
79}
80
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Canonicalizes a 10-element sparse bool array patched at indices 0, 1 and 7 and checks
    /// both the bit buffer and the validity for true, false and null fill values.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
                .into_array();
        let sparse_bools =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullability::Nullable));

        let flat_bools = sparse_bools.to_bool().unwrap();

        // Start from the fill everywhere, then overwrite the patched positions.
        let mut dense = vec![fill_value; 10];
        dense[0] = Some(true);
        dense[1] = None;
        dense[7] = Some(false);
        let expected = bool_array_from_nullable_vec(dense, fill_value);

        assert_eq!(flat_bools.boolean_buffer(), expected.boolean_buffer());
        assert_eq!(flat_bools.validity(), expected.validity());

        let bits = flat_bools.boolean_buffer();
        let validity = flat_bools.validity();

        // Index 0: patched with a valid `true`.
        assert!(bits.value(0));
        assert!(validity.is_valid(0).unwrap());
        // Index 1: patched with a null; the underlying bit is the fill (or the default).
        assert_eq!(bits.value(1), fill_value.unwrap_or_default());
        assert!(!validity.is_valid(1).unwrap());
        // Index 2: unpatched, so it is valid only when the fill is non-null.
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());
        // Index 7: patched with a valid `false`.
        assert!(!bits.value(7));
        assert!(validity.is_valid(7).unwrap());
    }

    /// Builds a nullable `BoolArray` from optional values. The bit stored under a `None`
    /// entry is `fill_value`, or `false` when the fill itself is `None`.
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        let placeholder = fill_value.unwrap_or_default();
        let mut bits = BooleanBufferBuilder::new(bools.len());
        let mut valid = BooleanBufferBuilder::new(bools.len());
        for entry in bools {
            bits.append(entry.unwrap_or(placeholder));
            valid.append(entry.is_some());
        }
        BoolArray::new(bits.finish(), Validity::from(valid.finish()))
    }

    /// Canonicalizes a 10-element sparse `i32` array patched at indices 0, 1 and 7 and checks
    /// both the value buffer and the validity for zero, non-zero and null fill values.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse_ints =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(
            *sparse_ints.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let flat_ints = sparse_ints.to_primitive().unwrap();

        // Start from the fill everywhere, then overwrite the patched positions.
        let mut dense = [fill_value; 10];
        dense[0] = Some(0i32);
        dense[1] = None;
        dense[7] = Some(1);
        let expected = PrimitiveArray::from_option_iter(dense);

        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
        assert_eq!(flat_ints.validity(), expected.validity());

        let ints = flat_ints.as_slice::<i32>();
        let validity = flat_ints.validity();

        // Index 0: patched with a valid 0.
        assert_eq!(ints[0], 0);
        assert!(validity.is_valid(0).unwrap());
        // Index 1: patched with a null; the stored value is 0.
        assert_eq!(ints[1], 0);
        assert!(!validity.is_valid(1).unwrap());
        // Index 2: unpatched, so it holds the fill (or the default) and mirrors its nullness.
        assert_eq!(ints[2], fill_value.unwrap_or_default());
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());
        // Index 7: patched with a valid 1.
        assert_eq!(ints[7], 1);
        assert!(validity.is_valid(7).unwrap());
    }
}