1use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
// Invokes the `vtable!` macro imported from `vortex_array` for the `Sparse` encoding.
// NOTE(review): presumably this generates `SparseVTable` and the glue that lets
// `SparseArray` be used as an `Array` — confirm against the macro definition.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28 type Array = SparseArray;
29 type Encoding = SparseEncoding;
30
31 type ArrayVTable = Self;
32 type CanonicalVTable = Self;
33 type OperationsVTable = Self;
34 type ValidityVTable = Self;
35 type VisitorVTable = Self;
36 type ComputeVTable = NotSupported;
37 type EncodeVTable = Self;
38 type SerdeVTable = Self;
39
40 fn id(_encoding: &Self::Encoding) -> EncodingId {
41 EncodingId::new_ref("vortex.sparse")
42 }
43
44 fn encoding(_array: &Self::Array) -> EncodingRef {
45 EncodingRef::new_ref(SparseEncoding.as_ref())
46 }
47}
48
/// An array described by a single fill value plus a set of patches.
///
/// Positions covered by a patch take the patch value; every other position
/// takes `fill_value`. The array's dtype is the dtype of `fill_value` and its
/// length is the patches' `array_len()`.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Indices and values of the positions that differ from the fill value.
    patches: Patches,
    /// Scalar used for every position that has no patch.
    fill_value: Scalar,
    /// Statistics storage for this array, exposed through the vtable's `stats`.
    stats_set: ArrayStats,
}
55
/// Stateless marker type for the sparse encoding (identified as `"vortex.sparse"`).
#[derive(Clone, Debug)]
pub struct SparseEncoding;
58
59impl SparseArray {
60 pub fn try_new(
61 indices: ArrayRef,
62 values: ArrayRef,
63 len: usize,
64 fill_value: Scalar,
65 ) -> VortexResult<Self> {
66 vortex_ensure!(
67 indices.len() == values.len(),
68 "Mismatched indices {} and values {} length",
69 indices.len(),
70 values.len()
71 );
72
73 vortex_ensure!(
74 indices.statistics().compute_is_strict_sorted() == Some(true),
75 "SparseArray: indices must be strict-sorted"
76 );
77
78 if !indices.is_empty() {
80 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
81
82 vortex_ensure!(
83 last_index < len,
84 "Array length was {len} but the last index is {last_index}"
85 );
86 }
87
88 let patches = Patches::new(len, 0, indices, values);
89
90 Ok(Self {
91 patches,
92 fill_value,
93 stats_set: Default::default(),
94 })
95 }
96
97 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
99 vortex_ensure!(
100 fill_value.dtype() == patches.values().dtype(),
101 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
102 fill_value,
103 patches.values().dtype(),
104 fill_value.dtype(),
105 );
106
107 Ok(Self {
108 patches,
109 fill_value,
110 stats_set: Default::default(),
111 })
112 }
113
114 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
115 Self {
116 patches,
117 fill_value,
118 stats_set: Default::default(),
119 }
120 }
121
    /// Returns the patches as stored, i.e. with indices that may still carry
    /// the patches' offset (see `resolved_patches` for the rebased form).
    #[inline]
    pub fn patches(&self) -> &Patches {
        &self.patches
    }
126
127 #[inline]
128 pub fn resolved_patches(&self) -> VortexResult<Patches> {
129 let patches = self.patches();
130 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
131 let indices = sub_scalar(patches.indices(), indices_offset)?;
132 Ok(Patches::new(
133 patches.array_len(),
134 0,
135 indices,
136 patches.values().clone(),
137 ))
138 }
139
    /// Returns the scalar used for every position that is not patched.
    #[inline]
    pub fn fill_scalar(&self) -> &Scalar {
        &self.fill_value
    }
144
145 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
149 if let Some(fill_value) = fill_value.as_ref()
150 && array.dtype() != fill_value.dtype()
151 {
152 vortex_bail!(
153 "Array and fill value types must match. got {} and {}",
154 array.dtype(),
155 fill_value.dtype()
156 )
157 }
158 let mask = array.validity_mask()?;
159
160 if mask.all_false() {
161 return Ok(
163 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
164 );
165 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
166 let non_null_values = filter(array, &mask)?;
168 let non_null_indices = match mask.indices() {
169 AllOr::All => {
170 unreachable!("Mask is mostly null")
172 }
173 AllOr::None => {
174 unreachable!("Mask is mostly null but not all null")
176 }
177 AllOr::Some(values) => {
178 let buffer: Buffer<u32> = values
179 .iter()
180 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
181 .collect();
182
183 buffer.into_array()
184 }
185 };
186
187 return Ok(SparseArray::try_new(
188 non_null_indices,
189 non_null_values,
190 array.len(),
191 Scalar::null(array.dtype().clone()),
192 )?
193 .into_array());
194 }
195
196 let fill = if let Some(fill) = fill_value {
197 fill
198 } else {
199 let (top_pvalue, _) = array
201 .to_primitive()?
202 .top_value()?
203 .vortex_expect("Non empty or all null array");
204
205 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
206 };
207
208 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
209 let non_top_mask = Mask::from_buffer(
210 fill_null(
211 &compare(array, &fill_array, Operator::NotEq)?,
212 &Scalar::bool(true, Nullability::NonNullable),
213 )?
214 .to_bool()?
215 .boolean_buffer()
216 .clone(),
217 );
218
219 let non_top_values = filter(array, &non_top_mask)?;
220
221 let indices: Buffer<u64> = match non_top_mask {
222 Mask::AllTrue(count) => {
223 (0u64..count as u64).collect()
225 }
226 Mask::AllFalse(_) => {
227 return Ok(fill_array);
229 }
230 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
231 };
232
233 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
234 .map(|a| a.into_array())
235 }
236}
237
238impl ArrayVTable<SparseVTable> for SparseVTable {
239 fn len(array: &SparseArray) -> usize {
240 array.patches.array_len()
241 }
242
243 fn dtype(array: &SparseArray) -> &DType {
244 array.fill_scalar().dtype()
245 }
246
247 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
248 array.stats_set.to_ref(array.as_ref())
249 }
250}
251
252impl ValidityVTable<SparseVTable> for SparseVTable {
253 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
254 Ok(match array.patches().get_patched(index) {
255 None => array.fill_scalar().is_valid(),
256 Some(patch_value) => patch_value.is_valid(),
257 })
258 }
259
260 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
261 if array.fill_scalar().is_null() {
262 return Ok(array.patches().values().len() == array.len()
264 && array.patches().values().all_valid()?);
265 }
266
267 array.patches().values().all_valid()
268 }
269
270 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
271 if !array.fill_scalar().is_null() {
272 return Ok(array.patches().values().len() == array.len()
274 && array.patches().values().all_invalid()?);
275 }
276
277 array.patches().values().all_invalid()
278 }
279
280 #[allow(clippy::unnecessary_fallible_conversions)]
281 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
282 let fill_is_valid = array.fill_scalar().is_valid();
283 let values_validity = array.patches().values().validity_mask()?;
284 let len = array.len();
285
286 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
287 return Ok(Mask::AllTrue(len));
288 }
289 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
290 return Ok(Mask::AllFalse(len));
291 }
292
293 let mut is_valid_buffer = BooleanBufferBuilder::new(len);
295 is_valid_buffer.append_n(len, fill_is_valid);
296
297 let indices = array.patches().indices().to_primitive()?;
298 let index_offset = array.patches().offset();
299
300 match_each_integer_ptype!(indices.ptype(), |I| {
301 let indices = indices.as_slice::<I>();
302 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
303 });
304
305 Ok(Mask::from_buffer(is_valid_buffer.finish()))
306 }
307}
308
309fn patch_validity<I: NativePType>(
310 is_valid_buffer: &mut BooleanBufferBuilder,
311 indices: &[I],
312 index_offset: usize,
313 values_validity: Mask,
314) {
315 let indices = indices.iter().map(|index| {
316 let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
317 index - index_offset
318 });
319 match values_validity {
320 Mask::AllTrue(_) => {
321 for index in indices {
322 is_valid_buffer.set_bit(index, true);
323 }
324 }
325 Mask::AllFalse(_) => {
326 for index in indices {
327 is_valid_buffer.set_bit(index, false);
328 }
329 }
330 Mask::Values(mask_values) => {
331 let is_valid = mask_values.boolean_buffer().iter();
332 for (index, is_valid) in indices.zip_eq(is_valid) {
333 is_valid_buffer.set_bit(index, is_valid);
334 }
335 }
336 }
337}
338
// Unit tests for `SparseArray` construction, scalar access, slicing, validity
// masks and `SparseArray::encode`.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    // A null fill value of nullable i32 dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    // A non-null i32 fill value (42).
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    // Builds a length-10 sparse array with values 100, 200, 300 at positions
    // 2, 5, 8 and the given fill value everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // Cast the values so that their dtype equals the fill value's dtype.
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions yield the fill value; patched positions yield the patch.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    // Index 10 is one past the end of the length-10 array and must panic.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let _ = array.scalar_at(10);
    }

    // Same as above, but with constant arrays as indices and values: a single
    // patch (1234 at position 10) in a length-100 array with a null fill.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    // After slicing to [2, 7), the patch at original position 2 is at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    // With a null fill, only the patched positions (original 2 and 5) are valid
    // in the slice [2, 7).
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // With a non-null fill and null patch values, the patched positions are the
    // invalid ones in the slice [2, 7).
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7);

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slicing an already sliced array must keep patch positions consistent.
    #[test]
    pub fn scalar_at_sliced_twice() {
        // Slice [1, 8): original position 2 becomes position 1.
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        // Slice [1, 6) of that: original position 5 becomes position 3.
        let sliced_twice = sliced_once.slice(1, 6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    // With a null fill, exactly the patched positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // With a non-null fill and non-null patches, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    // The last index (100) is not smaller than the length (100), so `try_new`
    // fails and the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // The last index (100) is smaller than the length (101), so `try_new` succeeds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // Encodes a partially null primitive array without an explicit fill value
    // and checks that both the validity mask and the canonical values survive.
    // 5 of the 12 entries are null, which is below the >90% null threshold in
    // `encode`.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    // With a null fill, a patched position is valid only if its patch value is
    // non-null: patches at 0, 2, 4, 6, 8 with values [0, 1, null, null, 4].
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}