1use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayBufferVisitor;
12use vortex_array::ArrayChildVisitor;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::arrays::ConstantArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::compute::Operator;
24use vortex_array::compute::compare;
25use vortex_array::compute::fill_null;
26use vortex_array::compute::filter;
27use vortex_array::compute::sub_scalar;
28use vortex_array::patches::Patches;
29use vortex_array::patches::PatchesMetadata;
30use vortex_array::serde::ArrayChildren;
31use vortex_array::stats::ArrayStats;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::ArrayVTable;
37use vortex_array::vtable::ArrayVTableExt;
38use vortex_array::vtable::BaseArrayVTable;
39use vortex_array::vtable::EncodeVTable;
40use vortex_array::vtable::NotSupported;
41use vortex_array::vtable::VTable;
42use vortex_array::vtable::ValidityVTable;
43use vortex_array::vtable::VisitorVTable;
44use vortex_buffer::BitBufferMut;
45use vortex_buffer::Buffer;
46use vortex_buffer::ByteBufferMut;
47use vortex_dtype::DType;
48use vortex_dtype::NativePType;
49use vortex_dtype::Nullability;
50use vortex_dtype::match_each_integer_ptype;
51use vortex_error::VortexExpect as _;
52use vortex_error::VortexResult;
53use vortex_error::vortex_bail;
54use vortex_error::vortex_ensure;
55use vortex_mask::AllOr;
56use vortex_mask::Mask;
57use vortex_scalar::Scalar;
58use vortex_scalar::ScalarValue;
59
// Submodules of this encoding; their contents are not visible from this file.
// NOTE(review): the names suggest canonicalization, compute kernels and element/slice
// operations respectively — confirm against the modules themselves.
mod canonical;
mod compute;
mod ops;

// NOTE(review): presumably wires `SparseArray` / `SparseVTable` into the generic array
// machinery — confirm against the definition of the `vtable!` macro.
vtable!(Sparse);
65
/// Serialized metadata of a sparse array.
///
/// This is a `prost` message: it is encoded with `encode_to_vec` and decoded with `decode`
/// by the `VTable` implementation for `SparseVTable`.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata describing the patches child arrays (produced by `Patches::to_metadata`).
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
72
73impl VTable for SparseVTable {
74 type Array = SparseArray;
75
76 type Metadata = ProstMetadata<SparseMetadata>;
77
78 type ArrayVTable = Self;
79 type CanonicalVTable = Self;
80 type OperationsVTable = Self;
81 type ValidityVTable = Self;
82 type VisitorVTable = Self;
83 type ComputeVTable = NotSupported;
84 type EncodeVTable = Self;
85
86 fn id(&self) -> ArrayId {
87 ArrayId::new_ref("vortex.sparse")
88 }
89
90 fn encoding(_array: &Self::Array) -> ArrayVTable {
91 SparseVTable.as_vtable()
92 }
93
94 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
95 Ok(ProstMetadata(SparseMetadata {
96 patches: array.patches().to_metadata(array.len(), array.dtype())?,
97 }))
98 }
99
100 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
101 Ok(Some(metadata.0.encode_to_vec()))
102 }
103
104 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
105 Ok(ProstMetadata(SparseMetadata::decode(buffer)?))
106 }
107
108 fn build(
109 &self,
110 dtype: &DType,
111 len: usize,
112 metadata: &Self::Metadata,
113 buffers: &[BufferHandle],
114 children: &dyn ArrayChildren,
115 ) -> VortexResult<SparseArray> {
116 if children.len() != 2 {
117 vortex_bail!(
118 "Expected 2 children for sparse encoding, found {}",
119 children.len()
120 )
121 }
122 assert_eq!(
123 metadata.0.patches.offset(),
124 0,
125 "Patches must start at offset 0"
126 );
127
128 let patch_indices = children.get(
129 0,
130 &metadata.0.patches.indices_dtype(),
131 metadata.0.patches.len(),
132 )?;
133 let patch_values = children.get(1, dtype, metadata.0.patches.len())?;
134
135 if buffers.len() != 1 {
136 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
137 }
138 let fill_value = Scalar::new(
139 dtype.clone(),
140 ScalarValue::from_protobytes(&buffers[0].clone().try_to_bytes()?)?,
141 );
142
143 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
144 }
145
146 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
147 vortex_ensure!(
148 children.len() == 2,
149 "SparseArray expects 2 children, got {}",
150 children.len()
151 );
152
153 let mut children_iter = children.into_iter();
154 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
155 let patch_values = children_iter.next().vortex_expect("patch_values child");
156
157 array.patches = Patches::new(
158 array.patches.array_len(),
159 array.patches.offset(),
160 patch_indices,
161 patch_values,
162 array.patches.chunk_offsets().clone(),
163 );
164
165 Ok(())
166 }
167}
168
/// An array represented as a set of patches (indices plus values) over a single fill scalar.
///
/// Positions that have no patch are treated as the fill value (see the validity
/// implementation, which falls back to the fill scalar when no patch is found).
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Patch indices and values, together with the logical array length and index offset.
    patches: Patches,
    /// Scalar used for unpatched positions; its dtype is reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics storage exposed through `BaseArrayVTable::stats`.
    stats_set: ArrayStats,
}
175
/// Zero-sized marker type carrying the vtable implementations for the sparse encoding.
#[derive(Debug)]
pub struct SparseVTable;
178
179impl SparseArray {
180 pub fn try_new(
181 indices: ArrayRef,
182 values: ArrayRef,
183 len: usize,
184 fill_value: Scalar,
185 ) -> VortexResult<Self> {
186 vortex_ensure!(
187 indices.len() == values.len(),
188 "Mismatched indices {} and values {} length",
189 indices.len(),
190 values.len()
191 );
192
193 vortex_ensure!(
194 indices.statistics().compute_is_strict_sorted() == Some(true),
195 "SparseArray: indices must be strict-sorted"
196 );
197
198 if !indices.is_empty() {
200 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
201
202 vortex_ensure!(
203 last_index < len,
204 "Array length was {len} but the last index is {last_index}"
205 );
206 }
207
208 Ok(Self {
209 patches: Patches::new(len, 0, indices, values, None),
211 fill_value,
212 stats_set: Default::default(),
213 })
214 }
215
216 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
218 vortex_ensure!(
219 fill_value.dtype() == patches.values().dtype(),
220 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
221 fill_value,
222 patches.values().dtype(),
223 fill_value.dtype(),
224 );
225
226 Ok(Self {
227 patches,
228 fill_value,
229 stats_set: Default::default(),
230 })
231 }
232
233 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
234 Self {
235 patches,
236 fill_value,
237 stats_set: Default::default(),
238 }
239 }
240
241 #[inline]
242 pub fn patches(&self) -> &Patches {
243 &self.patches
244 }
245
246 #[inline]
247 pub fn resolved_patches(&self) -> Patches {
248 let patches = self.patches();
249 let indices_offset = Scalar::from(patches.offset())
250 .cast(patches.indices().dtype())
251 .vortex_expect("Patches offset must cast to the indices dtype");
252 let indices = sub_scalar(patches.indices(), indices_offset)
253 .vortex_expect("must be able to subtract offset from indices");
254
255 Patches::new(
256 patches.array_len(),
257 0,
258 indices,
259 patches.values().clone(),
260 None,
262 )
263 }
264
265 #[inline]
266 pub fn fill_scalar(&self) -> &Scalar {
267 &self.fill_value
268 }
269
270 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
274 if let Some(fill_value) = fill_value.as_ref()
275 && array.dtype() != fill_value.dtype()
276 {
277 vortex_bail!(
278 "Array and fill value types must match. got {} and {}",
279 array.dtype(),
280 fill_value.dtype()
281 )
282 }
283 let mask = array.validity_mask();
284
285 if mask.all_false() {
286 return Ok(
288 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
289 );
290 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
291 let non_null_values = filter(array, &mask)?;
293 let non_null_indices = match mask.indices() {
294 AllOr::All => {
295 unreachable!("Mask is mostly null")
297 }
298 AllOr::None => {
299 unreachable!("Mask is mostly null but not all null")
301 }
302 AllOr::Some(values) => {
303 let buffer: Buffer<u32> = values
304 .iter()
305 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
306 .collect();
307
308 buffer.into_array()
309 }
310 };
311
312 return Ok(SparseArray::try_new(
313 non_null_indices,
314 non_null_values,
315 array.len(),
316 Scalar::null(array.dtype().clone()),
317 )?
318 .into_array());
319 }
320
321 let fill = if let Some(fill) = fill_value {
322 fill
323 } else {
324 let (top_pvalue, _) = array
326 .to_primitive()
327 .top_value()?
328 .vortex_expect("Non empty or all null array");
329
330 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
331 };
332
333 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
334 let non_top_mask = Mask::from_buffer(
335 fill_null(
336 &compare(array, &fill_array, Operator::NotEq)?,
337 &Scalar::bool(true, Nullability::NonNullable),
338 )?
339 .to_bool()
340 .bit_buffer()
341 .clone(),
342 );
343
344 let non_top_values = filter(array, &non_top_mask)?;
345
346 let indices: Buffer<u64> = match non_top_mask {
347 Mask::AllTrue(count) => {
348 (0u64..count as u64).collect()
350 }
351 Mask::AllFalse(_) => {
352 return Ok(fill_array);
354 }
355 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
356 };
357
358 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
359 .map(|a| a.into_array())
360 }
361}
362
363impl BaseArrayVTable<SparseVTable> for SparseVTable {
364 fn len(array: &SparseArray) -> usize {
365 array.patches.array_len()
366 }
367
368 fn dtype(array: &SparseArray) -> &DType {
369 array.fill_scalar().dtype()
370 }
371
372 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
373 array.stats_set.to_ref(array.as_ref())
374 }
375
376 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
377 array.patches.array_hash(state, precision);
378 array.fill_value.hash(state);
379 }
380
381 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
382 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
383 }
384}
385
386impl ValidityVTable<SparseVTable> for SparseVTable {
387 fn is_valid(array: &SparseArray, index: usize) -> bool {
388 match array.patches().get_patched(index) {
389 None => array.fill_scalar().is_valid(),
390 Some(patch_value) => patch_value.is_valid(),
391 }
392 }
393
394 fn all_valid(array: &SparseArray) -> bool {
395 if array.fill_scalar().is_null() {
396 return array.patches().values().len() == array.len()
398 && array.patches().values().all_valid();
399 }
400
401 array.patches().values().all_valid()
402 }
403
404 fn all_invalid(array: &SparseArray) -> bool {
405 if !array.fill_scalar().is_null() {
406 return array.patches().values().len() == array.len()
408 && array.patches().values().all_invalid();
409 }
410
411 array.patches().values().all_invalid()
412 }
413
414 fn validity(array: &SparseArray) -> VortexResult<Validity> {
415 let patches = unsafe {
416 Patches::new_unchecked(
417 array.patches.array_len(),
418 array.patches.offset(),
419 array.patches.indices().clone(),
420 array
421 .patches
422 .values()
423 .validity()?
424 .to_array(array.patches.values().len()),
425 array.patches.chunk_offsets().clone(),
426 array.patches.offset_within_chunk(),
427 )
428 };
429
430 Ok(Validity::Array(
431 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
432 .into_array(),
433 ))
434 }
435
436 fn validity_mask(array: &SparseArray) -> Mask {
437 let fill_is_valid = array.fill_scalar().is_valid();
438 let values_validity = array.patches().values().validity_mask();
439 let len = array.len();
440
441 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
442 return Mask::AllTrue(len);
443 }
444 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
445 return Mask::AllFalse(len);
446 }
447
448 let mut is_valid_buffer = if fill_is_valid {
449 BitBufferMut::new_set(len)
450 } else {
451 BitBufferMut::new_unset(len)
452 };
453
454 let indices = array.patches().indices().to_primitive();
455 let index_offset = array.patches().offset();
456
457 match_each_integer_ptype!(indices.ptype(), |I| {
458 let indices = indices.as_slice::<I>();
459 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
460 });
461
462 Mask::from_buffer(is_valid_buffer.freeze())
463 }
464}
465
466fn patch_validity<I: NativePType + AsPrimitive<usize>>(
467 is_valid_buffer: &mut BitBufferMut,
468 indices: &[I],
469 index_offset: usize,
470 values_validity: Mask,
471) {
472 let indices = indices.iter().map(|index| index.as_() - index_offset);
473 match values_validity {
474 Mask::AllTrue(_) => {
475 for index in indices {
476 is_valid_buffer.set(index);
477 }
478 }
479 Mask::AllFalse(_) => {
480 for index in indices {
481 is_valid_buffer.unset(index);
482 }
483 }
484 Mask::Values(mask_values) => {
485 let is_valid = mask_values.bit_buffer().iter();
486 for (index, is_valid) in indices.zip_eq(is_valid) {
487 is_valid_buffer.set_to(index, is_valid);
488 }
489 }
490 }
491}
492
493impl EncodeVTable<SparseVTable> for SparseVTable {
494 fn encode(
495 _vtable: &SparseVTable,
496 input: &Canonical,
497 like: Option<&SparseArray>,
498 ) -> VortexResult<Option<SparseArray>> {
499 let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
501
502 Ok(SparseArray::encode(input.as_ref(), fill_value)?
504 .as_opt::<SparseVTable>()
505 .cloned())
506 }
507}
508
509impl VisitorVTable<SparseVTable> for SparseVTable {
510 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
511 let fill_value_buffer = array
512 .fill_value
513 .value()
514 .to_protobytes::<ByteBufferMut>()
515 .freeze();
516 visitor.visit_buffer(&fill_value_buffer);
517 }
518
519 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
520 visitor.visit_patches(array.patches())
521 }
522}
523
/// Unit tests for construction, element access, slicing, validity and encoding of
/// `SparseArray`.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;
    use vortex_scalar::PrimitiveScalar;
    use vortex_scalar::Scalar;

    use super::*;

    /// A null fill value of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value of 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Builds a length-10 sparse array holding 100, 200, 300 at indices 2, 5, 8, with the
    /// values cast to the dtype of `fill_value`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // Cast so the values share the fill value's dtype.
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions yield the fill value; patched positions yield the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    /// Accessing index 10 of a length-10 array must panic with an "out of bounds" message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10);
    }

    /// Element access on an array whose indices and values are constant arrays of length 1.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    /// After slicing `2..7`, the patch at original index 2 is found at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    /// With a null fill, only the patched positions (original indices 2 and 5) of the slice
    /// `2..7` are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    /// With a non-null fill and null patch values, only the patched positions of the slice
    /// `2..7` are invalid.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    /// Slicing a slice keeps patches at the expected shifted positions.
    #[test]
    pub fn scalar_at_sliced_twice() {
        // Original index 2 is at position 1 of the slice `1..8`.
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        // Original index 5 is at position 3 after additionally slicing `1..6`.
        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    /// With a null fill, exactly the patched positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array.validity_mask().to_bit_buffer().iter().collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// With a non-null fill and non-null values, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    /// Construction fails (and `unwrap` panics) when the last index, 100, equals the length.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    /// Construction succeeds when the last index, 100, is smaller than the length, 101.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    /// Encoding an array containing nulls preserves both the validity mask and the
    /// canonical primitive values.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_expect("SparseArray::encode should succeed for test data");
        let canonical = sparse.to_primitive();
        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    /// With a null fill, null patch values (at indices 4 and 6) are invalid just like the
    /// unpatched positions.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}