1use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DynArray;
13use vortex_array::ExecutionCtx;
14use vortex_array::ExecutionStep;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ToCanonical;
18use vortex_array::arrays::ConstantArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::builtins::ArrayBuiltins;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::Nullability;
23use vortex_array::patches::Patches;
24use vortex_array::patches::PatchesMetadata;
25use vortex_array::scalar::Scalar;
26use vortex_array::scalar::ScalarValue;
27use vortex_array::scalar_fn::fns::operators::Operator;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSetRef;
31use vortex_array::validity::Validity;
32use vortex_array::vtable;
33use vortex_array::vtable::ArrayId;
34use vortex_array::vtable::VTable;
35use vortex_array::vtable::ValidityVTable;
36use vortex_array::vtable::patches_child;
37use vortex_array::vtable::patches_child_name;
38use vortex_array::vtable::patches_nchildren;
39use vortex_buffer::Buffer;
40use vortex_buffer::ByteBufferMut;
41use vortex_error::VortexExpect as _;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_ensure_eq;
46use vortex_error::vortex_panic;
47use vortex_mask::AllOr;
48use vortex_mask::Mask;
49use vortex_session::VortexSession;
50
51use crate::canonical::execute_sparse;
52use crate::rules::RULES;
53
// Submodules of the sparse encoding. This file directly uses `canonical`
// (for `execute_sparse`), `kernel` (for `PARENT_KERNELS`) and `rules` (for `RULES`).
mod canonical;
mod compute;
mod kernel;
mod ops;
mod rules;
mod slice;

// Invokes the `vtable!` macro from `vortex_array` for the `Sparse` encoding.
// NOTE(review): presumably this generates the glue tying `SparseArray` to
// `SparseVTable` -- confirm against the macro definition.
vtable!(Sparse);
62
/// Metadata of a [`SparseArray`]: the patch layout together with the fill scalar.
///
/// Produced by `VTable::metadata` and `VTable::deserialize`; consumed by
/// `VTable::serialize` (which only encodes `patches`) and `VTable::build`.
#[derive(Debug)]
pub struct SparseMetadata {
    /// Metadata describing the patches (indices/values children) of the array.
    patches: PatchesMetadata,
    /// Scalar reported for positions that no patch covers.
    fill_value: Scalar,
}
68
/// Protobuf message used to (de)serialize the patches part of [`SparseMetadata`].
///
/// Only the patches metadata is encoded in this message; the fill value is not part of
/// it and travels separately as buffer 0 of the array (see `VTable::buffer`).
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct ProstPatchesMetadata {
    /// Required field, protobuf tag 1.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
75
76impl VTable for SparseVTable {
77 type Array = SparseArray;
78
79 type Metadata = SparseMetadata;
80 type OperationsVTable = Self;
81 type ValidityVTable = Self;
82
83 fn id(_array: &Self::Array) -> ArrayId {
84 Self::ID
85 }
86
87 fn len(array: &SparseArray) -> usize {
88 array.patches.array_len()
89 }
90
91 fn dtype(array: &SparseArray) -> &DType {
92 array.fill_scalar().dtype()
93 }
94
95 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
96 array.stats_set.to_ref(array.as_ref())
97 }
98
99 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
100 array.patches.array_hash(state, precision);
101 array.fill_value.hash(state);
102 }
103
104 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
105 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
106 }
107
108 fn nbuffers(_array: &SparseArray) -> usize {
109 1
110 }
111
112 fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
113 match idx {
114 0 => {
115 let fill_value_buffer =
116 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
117 BufferHandle::new_host(fill_value_buffer)
118 }
119 _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
120 }
121 }
122
123 fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
124 match idx {
125 0 => Some("fill_value".to_string()),
126 _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
127 }
128 }
129
130 fn nchildren(array: &SparseArray) -> usize {
131 patches_nchildren(array.patches())
132 }
133
134 fn child(array: &SparseArray, idx: usize) -> ArrayRef {
135 patches_child(array.patches(), idx)
136 }
137
138 fn child_name(_array: &SparseArray, idx: usize) -> String {
139 patches_child_name(idx).to_string()
140 }
141
142 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
143 let patches = array.patches().to_metadata(array.len(), array.dtype())?;
144
145 Ok(SparseMetadata {
146 patches,
147 fill_value: array.fill_value.clone(),
148 })
149 }
150
151 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
152 let prost_patches = ProstPatchesMetadata {
153 patches: metadata.patches,
154 };
155
156 Ok(Some(prost_patches.encode_to_vec()))
158 }
159
160 fn deserialize(
161 bytes: &[u8],
162 dtype: &DType,
163 _len: usize,
164 buffers: &[BufferHandle],
165 session: &VortexSession,
166 ) -> VortexResult<Self::Metadata> {
167 let prost_patches = ProstPatchesMetadata::decode(bytes)?;
168
169 if buffers.len() != 1 {
172 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
173 }
174 let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
175
176 let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
177 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
178
179 Ok(SparseMetadata {
180 patches: prost_patches.patches,
181 fill_value,
182 })
183 }
184
185 fn build(
186 dtype: &DType,
187 len: usize,
188 metadata: &Self::Metadata,
189 _buffers: &[BufferHandle],
190 children: &dyn ArrayChildren,
191 ) -> VortexResult<SparseArray> {
192 vortex_ensure_eq!(
193 children.len(),
194 2,
195 "SparseArray expects 2 children for sparse encoding, found {}",
196 children.len()
197 );
198 vortex_ensure_eq!(
199 metadata.patches.offset()?,
200 0,
201 "Patches must start at offset 0"
202 );
203
204 let patch_indices = children.get(
205 0,
206 &metadata.patches.indices_dtype()?,
207 metadata.patches.len()?,
208 )?;
209 let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
210
211 SparseArray::try_new(
212 patch_indices,
213 patch_values,
214 len,
215 metadata.fill_value.clone(),
216 )
217 }
218
219 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
220 vortex_ensure_eq!(
221 children.len(),
222 2,
223 "SparseArray expects 2 children, got {}",
224 children.len()
225 );
226
227 let mut children_iter = children.into_iter();
228 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
229 let patch_values = children_iter.next().vortex_expect("patch_values child");
230
231 array.patches = Patches::new(
232 array.patches.array_len(),
233 array.patches.offset(),
234 patch_indices,
235 patch_values,
236 array.patches.chunk_offsets().clone(),
237 )?;
238
239 Ok(())
240 }
241
242 fn reduce_parent(
243 array: &Self::Array,
244 parent: &ArrayRef,
245 child_idx: usize,
246 ) -> VortexResult<Option<ArrayRef>> {
247 RULES.evaluate(array, parent, child_idx)
248 }
249
250 fn execute_parent(
251 array: &Self::Array,
252 parent: &ArrayRef,
253 child_idx: usize,
254 ctx: &mut ExecutionCtx,
255 ) -> VortexResult<Option<ArrayRef>> {
256 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
257 }
258
259 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
260 execute_sparse(array, ctx).map(ExecutionStep::Done)
261 }
262}
263
/// An array in which most positions share a single fill value and the remaining
/// positions are stored explicitly as patches (index/value pairs).
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicitly stored positions and their values; also records the array length.
    patches: Patches,
    /// Value reported for positions that no patch covers; its dtype is the array's dtype.
    fill_value: Scalar,
    /// Statistics associated with this array.
    stats_set: ArrayStats,
}
270
/// Marker type carrying the vtable implementations (`VTable`, `ValidityVTable`, ...)
/// for [`SparseArray`].
#[derive(Debug)]
pub struct SparseVTable;
273
impl SparseVTable {
    /// Identifier of the sparse encoding, `"vortex.sparse"`; returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
277
278impl SparseArray {
279 pub fn try_new(
280 indices: ArrayRef,
281 values: ArrayRef,
282 len: usize,
283 fill_value: Scalar,
284 ) -> VortexResult<Self> {
285 vortex_ensure!(
286 indices.len() == values.len(),
287 "Mismatched indices {} and values {} length",
288 indices.len(),
289 values.len()
290 );
291
292 if indices.is_host() {
293 debug_assert_eq!(
294 indices.statistics().compute_is_strict_sorted(),
295 Some(true),
296 "SparseArray: indices must be strict-sorted"
297 );
298
299 if !indices.is_empty() {
301 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
302
303 vortex_ensure!(
304 last_index < len,
305 "Array length was {len} but the last index is {last_index}"
306 );
307 }
308 }
309
310 Ok(Self {
311 patches: Patches::new(len, 0, indices, values, None)?,
313 fill_value,
314 stats_set: Default::default(),
315 })
316 }
317
318 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
320 vortex_ensure!(
321 fill_value.dtype() == patches.values().dtype(),
322 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
323 fill_value,
324 patches.values().dtype(),
325 fill_value.dtype(),
326 );
327
328 Ok(Self {
329 patches,
330 fill_value,
331 stats_set: Default::default(),
332 })
333 }
334
335 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
336 Self {
337 patches,
338 fill_value,
339 stats_set: Default::default(),
340 }
341 }
342
343 #[inline]
344 pub fn patches(&self) -> &Patches {
345 &self.patches
346 }
347
348 #[inline]
349 pub fn resolved_patches(&self) -> VortexResult<Patches> {
350 let patches = self.patches();
351 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
352 let indices = patches.indices().to_array().binary(
353 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
354 Operator::Sub,
355 )?;
356
357 Patches::new(
358 patches.array_len(),
359 0,
360 indices,
361 patches.values().clone(),
362 None,
364 )
365 }
366
367 #[inline]
368 pub fn fill_scalar(&self) -> &Scalar {
369 &self.fill_value
370 }
371
372 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
376 if let Some(fill_value) = fill_value.as_ref()
377 && array.dtype() != fill_value.dtype()
378 {
379 vortex_bail!(
380 "Array and fill value types must match. got {} and {}",
381 array.dtype(),
382 fill_value.dtype()
383 )
384 }
385 let mask = array.validity_mask()?;
386
387 if mask.all_false() {
388 return Ok(
390 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
391 );
392 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
393 let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
396 let non_null_indices = match mask.indices() {
397 AllOr::All => {
398 unreachable!("Mask is mostly null")
400 }
401 AllOr::None => {
402 unreachable!("Mask is mostly null but not all null")
404 }
405 AllOr::Some(values) => {
406 let buffer: Buffer<u32> = values
407 .iter()
408 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
409 .collect();
410
411 buffer.into_array()
412 }
413 };
414
415 return Ok(SparseArray::try_new(
416 non_null_indices,
417 non_null_values,
418 array.len(),
419 Scalar::null(array.dtype().clone()),
420 )?
421 .into_array());
422 }
423
424 let fill = if let Some(fill) = fill_value {
425 fill
426 } else {
427 let (top_pvalue, _) = array
429 .to_primitive()
430 .top_value()?
431 .vortex_expect("Non empty or all null array");
432
433 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
434 };
435
436 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
437 let non_top_mask = Mask::from_buffer(
438 array
439 .to_array()
440 .binary(fill_array.clone(), Operator::NotEq)?
441 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
442 .to_bool()
443 .to_bit_buffer(),
444 );
445
446 let non_top_values = array
447 .filter(non_top_mask.clone())?
448 .to_canonical()?
449 .into_array();
450
451 let indices: Buffer<u64> = match non_top_mask {
452 Mask::AllTrue(count) => {
453 (0u64..count as u64).collect()
455 }
456 Mask::AllFalse(_) => {
457 return Ok(fill_array);
459 }
460 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
461 };
462
463 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
464 .map(|a| a.into_array())
465 }
466}
467
468impl ValidityVTable<SparseVTable> for SparseVTable {
469 fn validity(array: &SparseArray) -> VortexResult<Validity> {
470 let patches = unsafe {
471 Patches::new_unchecked(
472 array.patches.array_len(),
473 array.patches.offset(),
474 array.patches.indices().clone(),
475 array
476 .patches
477 .values()
478 .validity()?
479 .to_array(array.patches.values().len()),
480 array.patches.chunk_offsets().clone(),
481 array.patches.offset_within_chunk(),
482 )
483 };
484
485 Ok(Validity::Array(
486 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
487 .into_array(),
488 ))
489 }
490}
491
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null scalar of nullable `i32` type.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// The non-nullable `i32` scalar `42`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8 and the
    /// given fill everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // Cast the values so that their dtype matches the fill value's dtype.
        let values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions yield the fill, patched positions yield the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (index, expected) in expectations {
            assert_eq!(array.scalar_at(index).unwrap(), expected);
        }
    }

    /// Reading one past the end must panic with an out-of-bounds message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    /// Same as `test_scalar_at`, but with constant arrays as indices and values.
    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        for index in [0, 99] {
            assert!(arr.scalar_at(index).unwrap().is_null());
        }
    }

    /// After slicing `2..7`, the patch at original position 2 is at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    /// With a null fill, only the patched positions inside the slice are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let expected = Mask::from_iter([true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    /// With null patch values and a valid fill, the patched positions are the invalid ones.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let indices = buffer![2u64, 5, 8].into_array();
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let fill = Scalar::primitive(1.0f32, Nullability::Nullable);

        let sliced = SparseArray::try_new(indices, null_values, 10, fill)
            .unwrap()
            .slice(2..7)
            .unwrap();

        let expected = Mask::from_iter([false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    /// Slicing an already sliced array keeps patch positions consistent.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let once_value = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let twice_value = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);
    }

    /// With a null fill, exactly the patched positions (2, 5, 8) are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let bits = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// With a non-nullable fill, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    /// The last index (100) is not below the length (100), so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    /// The last index (100) is below the length (101), so construction succeeds.
    #[test]
    fn test_valid_length() {
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    /// Encoding an array with nulls preserves both its validity and its values.
    #[test]
    fn encode_with_nulls() {
        let validity_pattern = [
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_pattern),
        );

        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_pattern)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    /// Null patch values must show up as invalid even when the fill itself is null.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let fill = Scalar::null_native::<i16>();

        let array = SparseArray::try_new(patch_indices, patch_values, 10, fill).unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        assert_eq!(array.validity_mask().unwrap(), expected);
    }
}