1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::vtable::CanonicalVTable;
5use vortex_array::{Array, Canonical};
6use vortex_buffer::buffer;
7use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
8use vortex_error::{VortexError, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::{SparseArray, SparseVTable};
12
13impl CanonicalVTable<SparseVTable> for SparseVTable {
14 fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
15 let resolved_patches = array.resolved_patches()?;
16 if resolved_patches.num_patches() == 0 {
17 return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18 }
19
20 if matches!(array.dtype(), DType::Bool(_)) {
21 canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
22 } else {
23 let ptype = PType::try_from(resolved_patches.values().dtype())?;
24 match_each_native_ptype!(ptype, |$P| {
25 canonicalize_sparse_primitives::<$P>(
26 &resolved_patches,
27 &array.fill_scalar(),
28 )
29 })
30 }
31 }
32}
33
34fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
35 let (fill_bool, validity) = if fill_value.is_null() {
36 (false, Validity::AllInvalid)
37 } else {
38 (
39 fill_value.try_into()?,
40 if patches.dtype().nullability() == Nullability::NonNullable {
41 Validity::NonNullable
42 } else {
43 Validity::AllValid
44 },
45 )
46 };
47
48 let bools = BoolArray::new(
49 if fill_bool {
50 BooleanBuffer::new_set(patches.array_len())
51 } else {
52 BooleanBuffer::new_unset(patches.array_len())
53 },
54 validity,
55 );
56
57 bools.patch(patches).map(Canonical::Bool)
58}
59
60fn canonicalize_sparse_primitives<
61 T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
62>(
63 patches: &Patches,
64 fill_value: &Scalar,
65) -> VortexResult<Canonical> {
66 let (primitive_fill, validity) = if fill_value.is_null() {
67 (T::default(), Validity::AllInvalid)
68 } else {
69 (
70 fill_value.try_into()?,
71 if patches.dtype().nullability() == Nullability::NonNullable {
72 Validity::NonNullable
73 } else {
74 Validity::AllValid
75 },
76 )
77 };
78
79 let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
80
81 parray.patch(patches).map(Canonical::Primitive)
82}
83
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Build a `BoolArray` with an explicit validity buffer from optional values.
    ///
    /// `None` entries are marked invalid, and their backing bit is set to `fill_value` (or
    /// `false` when `fill_value` is itself `None`).
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        let placeholder = fill_value.unwrap_or_default();
        let mut bits = BooleanBufferBuilder::new(bools.len());
        let mut valid = BooleanBufferBuilder::new(bools.len());
        for entry in &bools {
            bits.append(entry.unwrap_or(placeholder));
            valid.append(entry.is_some());
        }
        BoolArray::new(bits.finish(), Validity::from(valid.finish()))
    }

    // Sparse bools: patches at positions 0, 1 (null) and 7 over each kind of fill value.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
                .into_array();
        let sparse_bools =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullability::Nullable));

        let flat_bools = sparse_bools.to_bool().unwrap();

        // Dense expectation: the fill everywhere, overwritten at the patched positions.
        let mut dense = vec![fill_value; 10];
        dense[0] = Some(true);
        dense[1] = None;
        dense[7] = Some(false);
        let expected = bool_array_from_nullable_vec(dense, fill_value);

        let bits = flat_bools.boolean_buffer();
        let validity = flat_bools.validity();
        assert_eq!(bits, expected.boolean_buffer());
        assert_eq!(validity, expected.validity());

        // Position 0: valid `true` patch.
        assert!(bits.value(0));
        assert!(validity.is_valid(0).unwrap());
        // Position 1: null patch.
        assert_eq!(bits.value(1), fill_value.unwrap_or_default());
        assert!(!validity.is_valid(1).unwrap());
        // Position 2: unpatched, so it is valid only when the fill is non-null.
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());
        // Position 7: valid `false` patch.
        assert!(!bits.value(7));
        assert!(validity.is_valid(7).unwrap());
    }

    // Sparse i32s: patches at positions 0, 1 (null) and 7 over each kind of fill value.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse_ints =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(
            *sparse_ints.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let flat_ints = sparse_ints.to_primitive().unwrap();

        // Dense expectation: the fill everywhere, overwritten at the patched positions.
        let mut dense = [fill_value; 10];
        dense[0] = Some(0i32);
        dense[1] = None;
        dense[7] = Some(1);
        let expected = PrimitiveArray::from_option_iter(dense);

        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
        assert_eq!(flat_ints.validity(), expected.validity());

        let ints = flat_ints.as_slice::<i32>();
        let validity = flat_ints.validity();

        // Position 0: valid patch holding 0.
        assert_eq!(ints[0], 0);
        assert!(validity.is_valid(0).unwrap());
        // Position 1: null patch.
        assert_eq!(ints[1], 0);
        assert!(!validity.is_valid(1).unwrap());
        // Position 2: unpatched, so it holds the fill and is valid only for a non-null fill.
        assert_eq!(ints[2], fill_value.unwrap_or_default());
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());
        // Position 7: valid patch holding 1.
        assert_eq!(ints[7], 1);
        assert!(validity.is_valid(7).unwrap());
    }
}