1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
11use vortex_array::patches::Patches;
12use vortex_array::stats::{ArrayStats, StatsSetRef};
13use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
14use vortex_array::{
15 Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, IntoArray, Precision,
16 ToCanonical, vtable,
17};
18use vortex_buffer::{BitBufferMut, Buffer};
19use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
20use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
21use vortex_mask::{AllOr, Mask};
22use vortex_scalar::Scalar;
23
24mod canonical;
25mod compute;
26mod ops;
27mod serde;
28
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Sparse` encoding.
// NOTE(review): presumably this is what declares the `SparseVTable` type that the trait
// implementations in this file are written for — confirm against the macro definition.
vtable!(Sparse);
30
31impl VTable for SparseVTable {
32 type Array = SparseArray;
33 type Encoding = SparseEncoding;
34
35 type ArrayVTable = Self;
36 type CanonicalVTable = Self;
37 type OperationsVTable = Self;
38 type ValidityVTable = Self;
39 type VisitorVTable = Self;
40 type ComputeVTable = NotSupported;
41 type EncodeVTable = Self;
42 type SerdeVTable = Self;
43 type OperatorVTable = NotSupported;
44
45 fn id(_encoding: &Self::Encoding) -> EncodingId {
46 EncodingId::new_ref("vortex.sparse")
47 }
48
49 fn encoding(_array: &Self::Array) -> EncodingRef {
50 EncodingRef::new_ref(SparseEncoding.as_ref())
51 }
52}
53
/// An array that stores a set of patched positions explicitly and reports a single fill
/// scalar for every position that has no patch.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// The explicitly stored positions: their indices and their values.
    patches: Patches,
    /// Scalar used for every unpatched position; its dtype is reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics holder for this array; constructors initialise it with its default value.
    stats_set: ArrayStats,
}
60
/// Unit struct naming the sparse encoding; it is the `Encoding` associated type of
/// `SparseVTable`.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
63
64impl SparseArray {
65 pub fn try_new(
66 indices: ArrayRef,
67 values: ArrayRef,
68 len: usize,
69 fill_value: Scalar,
70 ) -> VortexResult<Self> {
71 vortex_ensure!(
72 indices.len() == values.len(),
73 "Mismatched indices {} and values {} length",
74 indices.len(),
75 values.len()
76 );
77
78 vortex_ensure!(
79 indices.statistics().compute_is_strict_sorted() == Some(true),
80 "SparseArray: indices must be strict-sorted"
81 );
82
83 if !indices.is_empty() {
85 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
86
87 vortex_ensure!(
88 last_index < len,
89 "Array length was {len} but the last index is {last_index}"
90 );
91 }
92
93 Ok(Self {
94 patches: Patches::new(len, 0, indices, values, None),
96 fill_value,
97 stats_set: Default::default(),
98 })
99 }
100
101 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
103 vortex_ensure!(
104 fill_value.dtype() == patches.values().dtype(),
105 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
106 fill_value,
107 patches.values().dtype(),
108 fill_value.dtype(),
109 );
110
111 Ok(Self {
112 patches,
113 fill_value,
114 stats_set: Default::default(),
115 })
116 }
117
118 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
119 Self {
120 patches,
121 fill_value,
122 stats_set: Default::default(),
123 }
124 }
125
126 #[inline]
127 pub fn patches(&self) -> &Patches {
128 &self.patches
129 }
130
131 #[inline]
132 pub fn resolved_patches(&self) -> Patches {
133 let patches = self.patches();
134 let indices_offset = Scalar::from(patches.offset())
135 .cast(patches.indices().dtype())
136 .vortex_expect("Patches offset must cast to the indices dtype");
137 let indices = sub_scalar(patches.indices(), indices_offset)
138 .vortex_expect("must be able to subtract offset from indices");
139
140 Patches::new(
141 patches.array_len(),
142 0,
143 indices,
144 patches.values().clone(),
145 None,
147 )
148 }
149
150 #[inline]
151 pub fn fill_scalar(&self) -> &Scalar {
152 &self.fill_value
153 }
154
155 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
159 if let Some(fill_value) = fill_value.as_ref()
160 && array.dtype() != fill_value.dtype()
161 {
162 vortex_bail!(
163 "Array and fill value types must match. got {} and {}",
164 array.dtype(),
165 fill_value.dtype()
166 )
167 }
168 let mask = array.validity_mask();
169
170 if mask.all_false() {
171 return Ok(
173 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
174 );
175 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
176 let non_null_values = filter(array, &mask)?;
178 let non_null_indices = match mask.indices() {
179 AllOr::All => {
180 unreachable!("Mask is mostly null")
182 }
183 AllOr::None => {
184 unreachable!("Mask is mostly null but not all null")
186 }
187 AllOr::Some(values) => {
188 let buffer: Buffer<u32> = values
189 .iter()
190 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
191 .collect();
192
193 buffer.into_array()
194 }
195 };
196
197 return Ok(SparseArray::try_new(
198 non_null_indices,
199 non_null_values,
200 array.len(),
201 Scalar::null(array.dtype().clone()),
202 )?
203 .into_array());
204 }
205
206 let fill = if let Some(fill) = fill_value {
207 fill
208 } else {
209 let (top_pvalue, _) = array
211 .to_primitive()
212 .top_value()?
213 .vortex_expect("Non empty or all null array");
214
215 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
216 };
217
218 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
219 let non_top_mask = Mask::from_buffer(
220 fill_null(
221 &compare(array, &fill_array, Operator::NotEq)?,
222 &Scalar::bool(true, Nullability::NonNullable),
223 )?
224 .to_bool()
225 .bit_buffer()
226 .clone(),
227 );
228
229 let non_top_values = filter(array, &non_top_mask)?;
230
231 let indices: Buffer<u64> = match non_top_mask {
232 Mask::AllTrue(count) => {
233 (0u64..count as u64).collect()
235 }
236 Mask::AllFalse(_) => {
237 return Ok(fill_array);
239 }
240 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
241 };
242
243 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
244 .map(|a| a.into_array())
245 }
246}
247
248impl ArrayVTable<SparseVTable> for SparseVTable {
249 fn len(array: &SparseArray) -> usize {
250 array.patches.array_len()
251 }
252
253 fn dtype(array: &SparseArray) -> &DType {
254 array.fill_scalar().dtype()
255 }
256
257 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
258 array.stats_set.to_ref(array.as_ref())
259 }
260
261 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
262 array.patches.array_hash(state, precision);
263 array.fill_value.hash(state);
264 }
265
266 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
267 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
268 }
269}
270
271impl ValidityVTable<SparseVTable> for SparseVTable {
272 fn is_valid(array: &SparseArray, index: usize) -> bool {
273 match array.patches().get_patched(index) {
274 None => array.fill_scalar().is_valid(),
275 Some(patch_value) => patch_value.is_valid(),
276 }
277 }
278
279 fn all_valid(array: &SparseArray) -> bool {
280 if array.fill_scalar().is_null() {
281 return array.patches().values().len() == array.len()
283 && array.patches().values().all_valid();
284 }
285
286 array.patches().values().all_valid()
287 }
288
289 fn all_invalid(array: &SparseArray) -> bool {
290 if !array.fill_scalar().is_null() {
291 return array.patches().values().len() == array.len()
293 && array.patches().values().all_invalid();
294 }
295
296 array.patches().values().all_invalid()
297 }
298
299 #[allow(clippy::unnecessary_fallible_conversions)]
300 fn validity_mask(array: &SparseArray) -> Mask {
301 let fill_is_valid = array.fill_scalar().is_valid();
302 let values_validity = array.patches().values().validity_mask();
303 let len = array.len();
304
305 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
306 return Mask::AllTrue(len);
307 }
308 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
309 return Mask::AllFalse(len);
310 }
311
312 let mut is_valid_buffer = if fill_is_valid {
313 BitBufferMut::new_set(len)
314 } else {
315 BitBufferMut::new_unset(len)
316 };
317
318 let indices = array.patches().indices().to_primitive();
319 let index_offset = array.patches().offset();
320
321 match_each_integer_ptype!(indices.ptype(), |I| {
322 let indices = indices.as_slice::<I>();
323 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
324 });
325
326 Mask::from_buffer(is_valid_buffer.freeze())
327 }
328}
329
330fn patch_validity<I: NativePType + AsPrimitive<usize>>(
331 is_valid_buffer: &mut BitBufferMut,
332 indices: &[I],
333 index_offset: usize,
334 values_validity: Mask,
335) {
336 let indices = indices.iter().map(|index| index.as_() - index_offset);
337 match values_validity {
338 Mask::AllTrue(_) => {
339 for index in indices {
340 is_valid_buffer.set(index);
341 }
342 }
343 Mask::AllFalse(_) => {
344 for index in indices {
345 is_valid_buffer.unset(index);
346 }
347 }
348 Mask::Values(mask_values) => {
349 let is_valid = mask_values.bit_buffer().iter();
350 for (index, is_valid) in indices.zip_eq(is_valid) {
351 is_valid_buffer.set_to(index, is_valid);
352 }
353 }
354 }
355}
356
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// The scalar `42i32`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Fixture: a length-10 sparse array holding 100, 200, 300 at positions 2, 5, 8, with
    /// the patch values cast to the dtype of `fill_value`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let patch_values = cast(&raw_values, fill_value.dtype()).unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // Unpatched position yields the fill.
        assert_eq!(array.scalar_at(0), nullable_fill());
        // Patched positions yield their patch values.
        for (position, expected) in [(2, 100_i32), (5, 200_i32)] {
            assert_eq!(array.scalar_at(position), Scalar::from(Some(expected)));
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let ten_long = sparse_array(nullable_fill());
        ten_long.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let single_index = ConstantArray::new(10u32, 1).into_array();
        let single_value =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let null_fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));

        let arr = SparseArray::try_new(single_index, single_value, 100, null_fill).unwrap();

        let patched = PrimitiveScalar::try_from(&arr.scalar_at(10))
            .unwrap()
            .typed_value::<u32>();
        assert_eq!(patched, Some(1234));

        for unpatched in [0, 99] {
            assert!(arr.scalar_at(unpatched).is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        let first = usize::try_from(&window.scalar_at(0)).unwrap();
        assert_eq!(first, 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let valid_fill = Scalar::primitive(1.0f32, Nullability::Nullable);

        let window = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            valid_fill,
        )
        .unwrap()
        .slice(2..7);

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&once.scalar_at(1)).unwrap(), 100);

        let twice = once.slice(1..6);
        assert_eq!(usize::try_from(&twice.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let bits = array.validity_mask().to_bit_buffer().iter().collect_vec();
        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Last index 100 is not below the declared length 100.
        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Last index 100 is below the declared length 101.
        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_pattern = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_pattern.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive();

        assert_eq!(sparse.validity_mask(), Mask::from_iter(validity_pattern));
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let array =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_typed::<i16>())
                .unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask();

        assert_eq!(actual, expected);
    }
}