1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayBufferVisitor;
12use vortex_array::ArrayChildVisitor;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::arrays::ConstantArray;
22use vortex_array::compute::Operator;
23use vortex_array::compute::compare;
24use vortex_array::compute::fill_null;
25use vortex_array::compute::filter;
26use vortex_array::compute::sub_scalar;
27use vortex_array::patches::Patches;
28use vortex_array::patches::PatchesMetadata;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::vtable;
33use vortex_array::vtable::ArrayId;
34use vortex_array::vtable::ArrayVTable;
35use vortex_array::vtable::ArrayVTableExt;
36use vortex_array::vtable::BaseArrayVTable;
37use vortex_array::vtable::EncodeVTable;
38use vortex_array::vtable::NotSupported;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityVTable;
41use vortex_array::vtable::VisitorVTable;
42use vortex_buffer::BitBufferMut;
43use vortex_buffer::Buffer;
44use vortex_buffer::BufferHandle;
45use vortex_buffer::ByteBufferMut;
46use vortex_dtype::DType;
47use vortex_dtype::NativePType;
48use vortex_dtype::Nullability;
49use vortex_dtype::match_each_integer_ptype;
50use vortex_error::VortexExpect as _;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_mask::AllOr;
55use vortex_mask::Mask;
56use vortex_scalar::Scalar;
57use vortex_scalar::ScalarValue;
58
// Private submodules of this module; their contents live in separate files.
mod canonical;
mod compute;
mod ops;

// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Sparse` encoding.
// NOTE(review): presumably this generates the glue that ties `SparseArray` to
// `SparseVTable` — confirm against the macro definition.
vtable!(Sparse);
64
/// Metadata stored alongside a serialized sparse array, encoded as a `prost` message.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata describing the patches (the explicitly stored indices and values).
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
71
72impl VTable for SparseVTable {
73 type Array = SparseArray;
74
75 type Metadata = ProstMetadata<SparseMetadata>;
76
77 type ArrayVTable = Self;
78 type CanonicalVTable = Self;
79 type OperationsVTable = Self;
80 type ValidityVTable = Self;
81 type VisitorVTable = Self;
82 type ComputeVTable = NotSupported;
83 type EncodeVTable = Self;
84
85 fn id(&self) -> ArrayId {
86 ArrayId::new_ref("vortex.sparse")
87 }
88
89 fn encoding(_array: &Self::Array) -> ArrayVTable {
90 SparseVTable.as_vtable()
91 }
92
93 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
94 Ok(ProstMetadata(SparseMetadata {
95 patches: array.patches().to_metadata(array.len(), array.dtype())?,
96 }))
97 }
98
99 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100 Ok(Some(metadata.0.encode_to_vec()))
101 }
102
103 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
104 Ok(ProstMetadata(SparseMetadata::decode(buffer)?))
105 }
106
107 fn build(
108 &self,
109 dtype: &DType,
110 len: usize,
111 metadata: &Self::Metadata,
112 buffers: &[BufferHandle],
113 children: &dyn ArrayChildren,
114 ) -> VortexResult<SparseArray> {
115 if children.len() != 2 {
116 vortex_bail!(
117 "Expected 2 children for sparse encoding, found {}",
118 children.len()
119 )
120 }
121 assert_eq!(
122 metadata.0.patches.offset(),
123 0,
124 "Patches must start at offset 0"
125 );
126
127 let patch_indices = children.get(
128 0,
129 &metadata.0.patches.indices_dtype(),
130 metadata.0.patches.len(),
131 )?;
132 let patch_values = children.get(1, dtype, metadata.0.patches.len())?;
133
134 if buffers.len() != 1 {
135 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
136 }
137 let fill_value = Scalar::new(
138 dtype.clone(),
139 ScalarValue::from_protobytes(&buffers[0].clone().try_to_bytes()?)?,
140 );
141
142 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
143 }
144}
145
/// An array described by a set of patches (indices plus values) and a single fill scalar.
///
/// Validity lookups consult the patches first and fall back to the fill scalar for
/// positions that have no patch.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// The explicitly stored indices and values; also carries the array length.
    patches: Patches,
    /// Scalar used for positions without a patch; its dtype is reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics holder; every constructor in this file initialises it to its default.
    stats_set: ArrayStats,
}
152
/// Unit struct on which the vtable traits for the sparse encoding are implemented.
#[derive(Debug)]
pub struct SparseVTable;
155
156impl SparseArray {
157 pub fn try_new(
158 indices: ArrayRef,
159 values: ArrayRef,
160 len: usize,
161 fill_value: Scalar,
162 ) -> VortexResult<Self> {
163 vortex_ensure!(
164 indices.len() == values.len(),
165 "Mismatched indices {} and values {} length",
166 indices.len(),
167 values.len()
168 );
169
170 vortex_ensure!(
171 indices.statistics().compute_is_strict_sorted() == Some(true),
172 "SparseArray: indices must be strict-sorted"
173 );
174
175 if !indices.is_empty() {
177 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
178
179 vortex_ensure!(
180 last_index < len,
181 "Array length was {len} but the last index is {last_index}"
182 );
183 }
184
185 Ok(Self {
186 patches: Patches::new(len, 0, indices, values, None),
188 fill_value,
189 stats_set: Default::default(),
190 })
191 }
192
193 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
195 vortex_ensure!(
196 fill_value.dtype() == patches.values().dtype(),
197 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
198 fill_value,
199 patches.values().dtype(),
200 fill_value.dtype(),
201 );
202
203 Ok(Self {
204 patches,
205 fill_value,
206 stats_set: Default::default(),
207 })
208 }
209
210 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
211 Self {
212 patches,
213 fill_value,
214 stats_set: Default::default(),
215 }
216 }
217
218 #[inline]
219 pub fn patches(&self) -> &Patches {
220 &self.patches
221 }
222
223 #[inline]
224 pub fn resolved_patches(&self) -> Patches {
225 let patches = self.patches();
226 let indices_offset = Scalar::from(patches.offset())
227 .cast(patches.indices().dtype())
228 .vortex_expect("Patches offset must cast to the indices dtype");
229 let indices = sub_scalar(patches.indices(), indices_offset)
230 .vortex_expect("must be able to subtract offset from indices");
231
232 Patches::new(
233 patches.array_len(),
234 0,
235 indices,
236 patches.values().clone(),
237 None,
239 )
240 }
241
242 #[inline]
243 pub fn fill_scalar(&self) -> &Scalar {
244 &self.fill_value
245 }
246
247 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
251 if let Some(fill_value) = fill_value.as_ref()
252 && array.dtype() != fill_value.dtype()
253 {
254 vortex_bail!(
255 "Array and fill value types must match. got {} and {}",
256 array.dtype(),
257 fill_value.dtype()
258 )
259 }
260 let mask = array.validity_mask();
261
262 if mask.all_false() {
263 return Ok(
265 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
266 );
267 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
268 let non_null_values = filter(array, &mask)?;
270 let non_null_indices = match mask.indices() {
271 AllOr::All => {
272 unreachable!("Mask is mostly null")
274 }
275 AllOr::None => {
276 unreachable!("Mask is mostly null but not all null")
278 }
279 AllOr::Some(values) => {
280 let buffer: Buffer<u32> = values
281 .iter()
282 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
283 .collect();
284
285 buffer.into_array()
286 }
287 };
288
289 return Ok(SparseArray::try_new(
290 non_null_indices,
291 non_null_values,
292 array.len(),
293 Scalar::null(array.dtype().clone()),
294 )?
295 .into_array());
296 }
297
298 let fill = if let Some(fill) = fill_value {
299 fill
300 } else {
301 let (top_pvalue, _) = array
303 .to_primitive()
304 .top_value()?
305 .vortex_expect("Non empty or all null array");
306
307 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
308 };
309
310 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
311 let non_top_mask = Mask::from_buffer(
312 fill_null(
313 &compare(array, &fill_array, Operator::NotEq)?,
314 &Scalar::bool(true, Nullability::NonNullable),
315 )?
316 .to_bool()
317 .bit_buffer()
318 .clone(),
319 );
320
321 let non_top_values = filter(array, &non_top_mask)?;
322
323 let indices: Buffer<u64> = match non_top_mask {
324 Mask::AllTrue(count) => {
325 (0u64..count as u64).collect()
327 }
328 Mask::AllFalse(_) => {
329 return Ok(fill_array);
331 }
332 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
333 };
334
335 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
336 .map(|a| a.into_array())
337 }
338}
339
340impl BaseArrayVTable<SparseVTable> for SparseVTable {
341 fn len(array: &SparseArray) -> usize {
342 array.patches.array_len()
343 }
344
345 fn dtype(array: &SparseArray) -> &DType {
346 array.fill_scalar().dtype()
347 }
348
349 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
350 array.stats_set.to_ref(array.as_ref())
351 }
352
353 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
354 array.patches.array_hash(state, precision);
355 array.fill_value.hash(state);
356 }
357
358 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
359 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
360 }
361}
362
363impl ValidityVTable<SparseVTable> for SparseVTable {
364 fn is_valid(array: &SparseArray, index: usize) -> bool {
365 match array.patches().get_patched(index) {
366 None => array.fill_scalar().is_valid(),
367 Some(patch_value) => patch_value.is_valid(),
368 }
369 }
370
371 fn all_valid(array: &SparseArray) -> bool {
372 if array.fill_scalar().is_null() {
373 return array.patches().values().len() == array.len()
375 && array.patches().values().all_valid();
376 }
377
378 array.patches().values().all_valid()
379 }
380
381 fn all_invalid(array: &SparseArray) -> bool {
382 if !array.fill_scalar().is_null() {
383 return array.patches().values().len() == array.len()
385 && array.patches().values().all_invalid();
386 }
387
388 array.patches().values().all_invalid()
389 }
390
391 fn validity_mask(array: &SparseArray) -> Mask {
392 let fill_is_valid = array.fill_scalar().is_valid();
393 let values_validity = array.patches().values().validity_mask();
394 let len = array.len();
395
396 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
397 return Mask::AllTrue(len);
398 }
399 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
400 return Mask::AllFalse(len);
401 }
402
403 let mut is_valid_buffer = if fill_is_valid {
404 BitBufferMut::new_set(len)
405 } else {
406 BitBufferMut::new_unset(len)
407 };
408
409 let indices = array.patches().indices().to_primitive();
410 let index_offset = array.patches().offset();
411
412 match_each_integer_ptype!(indices.ptype(), |I| {
413 let indices = indices.as_slice::<I>();
414 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
415 });
416
417 Mask::from_buffer(is_valid_buffer.freeze())
418 }
419}
420
421fn patch_validity<I: NativePType + AsPrimitive<usize>>(
422 is_valid_buffer: &mut BitBufferMut,
423 indices: &[I],
424 index_offset: usize,
425 values_validity: Mask,
426) {
427 let indices = indices.iter().map(|index| index.as_() - index_offset);
428 match values_validity {
429 Mask::AllTrue(_) => {
430 for index in indices {
431 is_valid_buffer.set(index);
432 }
433 }
434 Mask::AllFalse(_) => {
435 for index in indices {
436 is_valid_buffer.unset(index);
437 }
438 }
439 Mask::Values(mask_values) => {
440 let is_valid = mask_values.bit_buffer().iter();
441 for (index, is_valid) in indices.zip_eq(is_valid) {
442 is_valid_buffer.set_to(index, is_valid);
443 }
444 }
445 }
446}
447
448impl EncodeVTable<SparseVTable> for SparseVTable {
449 fn encode(
450 _vtable: &SparseVTable,
451 input: &Canonical,
452 like: Option<&SparseArray>,
453 ) -> VortexResult<Option<SparseArray>> {
454 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
456
457 Ok(SparseArray::encode(input.as_ref(), fill_value)?
459 .as_opt::<SparseVTable>()
460 .cloned())
461 }
462}
463
464impl VisitorVTable<SparseVTable> for SparseVTable {
465 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
466 let fill_value_buffer = array
467 .fill_value
468 .value()
469 .to_protobytes::<ByteBufferMut>()
470 .freeze();
471 visitor.visit_buffer(&fill_value_buffer);
472 }
473
474 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
475 visitor.visit_patches(array.patches())
476 }
477}
478
// Unit tests for construction, element access, slicing, validity masks and encoding of
// sparse arrays.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::PrimitiveScalar;
    use vortex_scalar::Scalar;

    use super::*;

    /// A null fill scalar of type nullable `i32`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A fill scalar built from the `i32` value 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // Cast the values so that their dtype matches the dtype of the fill scalar.
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions yield the fill scalar, patched positions yield the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    // Reading one past the end must panic with an "out of bounds" message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10);
    }

    // Same as above, but with constant arrays as the indices and values children.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    // After slicing 2..7, the patch at original index 2 is at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    // With a null fill, only the patched positions of the slice are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // With a non-null fill and null patch values, only the patched positions are invalid.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slicing a slice keeps element positions consistent.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    // With a null fill, exactly the patched positions (2, 5, 8) are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array.validity_mask().to_bit_buffer().iter().collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // With a non-null fill and non-null patch values, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    // The last index (100) equals the length (100), so construction fails and `unwrap`
    // panics.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // The last index (100) is smaller than the length (101), so construction succeeds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // Encoding an array that contains nulls preserves both its validity and its values.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive();
        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    // With a null fill, null patch values (at indices 4 and 6) must also be invalid.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}