1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use kernel::PARENT_KERNELS;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityVTable;
37use vortex_array::vtable::patches_child;
38use vortex_array::vtable::patches_child_name;
39use vortex_array::vtable::patches_nchildren;
40use vortex_buffer::Buffer;
41use vortex_buffer::ByteBufferMut;
42use vortex_error::VortexExpect as _;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_ensure_eq;
47use vortex_error::vortex_panic;
48use vortex_mask::AllOr;
49use vortex_mask::Mask;
50use vortex_session::VortexSession;
51
52use crate::canonical::execute_sparse;
53use crate::rules::RULES;
54
55mod canonical;
56mod compute;
57mod kernel;
58mod ops;
59mod rules;
60mod slice;
61
// Invokes `vortex_array::vtable!` for the `Sparse` encoding.
// NOTE(review): presumably generates the glue between `Sparse` and `SparseArray`
// (e.g. the `into_array`/`as_ref` conversions used below) — confirm in the macro definition.
vtable!(Sparse);
63
/// Metadata of a [`SparseArray`]: the patches metadata together with the fill scalar.
///
/// Produced by `VTable::metadata` and rebuilt by `VTable::deserialize`.
#[derive(Debug)]
pub struct SparseMetadata {
    /// Metadata describing the patches (obtained from `Patches::to_metadata`).
    patches: PatchesMetadata,
    /// Copy of the array's fill scalar.
    fill_value: Scalar,
}
69
/// Protobuf message that is written as the serialized metadata bytes.
///
/// It carries only the patches metadata; the fill value is not part of this message —
/// it is exposed separately as buffer 0 (see `VTable::buffer` / `VTable::deserialize`).
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct ProstPatchesMetadata {
    /// Patches metadata, encoded as required field 1.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
76
77impl VTable for Sparse {
78 type Array = SparseArray;
79
80 type Metadata = SparseMetadata;
81 type OperationsVTable = Self;
82 type ValidityVTable = Self;
83
84 fn vtable(_array: &Self::Array) -> &Self {
85 &Sparse
86 }
87
88 fn id(&self) -> ArrayId {
89 Self::ID
90 }
91
92 fn len(array: &SparseArray) -> usize {
93 array.patches.array_len()
94 }
95
96 fn dtype(array: &SparseArray) -> &DType {
97 array.fill_scalar().dtype()
98 }
99
100 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
101 array.stats_set.to_ref(array.as_ref())
102 }
103
104 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
105 array.patches.array_hash(state, precision);
106 array.fill_value.hash(state);
107 }
108
109 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
110 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
111 }
112
113 fn nbuffers(_array: &SparseArray) -> usize {
114 1
115 }
116
117 fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
118 match idx {
119 0 => {
120 let fill_value_buffer =
121 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
122 BufferHandle::new_host(fill_value_buffer)
123 }
124 _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
125 }
126 }
127
128 fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
129 match idx {
130 0 => Some("fill_value".to_string()),
131 _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
132 }
133 }
134
135 fn nchildren(array: &SparseArray) -> usize {
136 patches_nchildren(array.patches())
137 }
138
139 fn child(array: &SparseArray, idx: usize) -> ArrayRef {
140 patches_child(array.patches(), idx)
141 }
142
143 fn child_name(_array: &SparseArray, idx: usize) -> String {
144 patches_child_name(idx).to_string()
145 }
146
147 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
148 let patches = array.patches().to_metadata(array.len(), array.dtype())?;
149
150 Ok(SparseMetadata {
151 patches,
152 fill_value: array.fill_value.clone(),
153 })
154 }
155
156 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
157 let prost_patches = ProstPatchesMetadata {
158 patches: metadata.patches,
159 };
160
161 Ok(Some(prost_patches.encode_to_vec()))
163 }
164
165 fn deserialize(
166 bytes: &[u8],
167 dtype: &DType,
168 _len: usize,
169 buffers: &[BufferHandle],
170 session: &VortexSession,
171 ) -> VortexResult<Self::Metadata> {
172 let prost_patches = ProstPatchesMetadata::decode(bytes)?;
173
174 if buffers.len() != 1 {
177 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
178 }
179 let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
180
181 let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
182 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
183
184 Ok(SparseMetadata {
185 patches: prost_patches.patches,
186 fill_value,
187 })
188 }
189
190 fn build(
191 dtype: &DType,
192 len: usize,
193 metadata: &Self::Metadata,
194 _buffers: &[BufferHandle],
195 children: &dyn ArrayChildren,
196 ) -> VortexResult<SparseArray> {
197 vortex_ensure_eq!(
198 children.len(),
199 2,
200 "SparseArray expects 2 children for sparse encoding, found {}",
201 children.len()
202 );
203
204 let patch_indices = children.get(
205 0,
206 &metadata.patches.indices_dtype()?,
207 metadata.patches.len()?,
208 )?;
209 let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
210
211 SparseArray::try_new_from_patches(
212 Patches::new(
213 len,
214 metadata.patches.offset()?,
215 patch_indices,
216 patch_values,
217 None,
218 )?,
219 metadata.fill_value.clone(),
220 )
221 }
222
223 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
224 vortex_ensure_eq!(
225 children.len(),
226 2,
227 "SparseArray expects 2 children, got {}",
228 children.len()
229 );
230
231 let mut children_iter = children.into_iter();
232 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
233 let patch_values = children_iter.next().vortex_expect("patch_values child");
234
235 array.patches = Patches::new(
236 array.patches.array_len(),
237 array.patches.offset(),
238 patch_indices,
239 patch_values,
240 array.patches.chunk_offsets().clone(),
241 )?;
242
243 Ok(())
244 }
245
246 fn reduce_parent(
247 array: &Self::Array,
248 parent: &ArrayRef,
249 child_idx: usize,
250 ) -> VortexResult<Option<ArrayRef>> {
251 RULES.evaluate(array, parent, child_idx)
252 }
253
254 fn execute_parent(
255 array: &Self::Array,
256 parent: &ArrayRef,
257 child_idx: usize,
258 ctx: &mut ExecutionCtx,
259 ) -> VortexResult<Option<ArrayRef>> {
260 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
261 }
262
263 fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
264 execute_sparse(&array, ctx).map(ExecutionResult::done)
265 }
266}
267
/// An array stored as a fill scalar plus a set of patches (index/value pairs).
///
/// The logical length comes from the patches and the dtype from the fill scalar.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Patch indices and values, together with the logical array length and offset.
    patches: Patches,
    /// Scalar reported for positions that have no patch.
    fill_value: Scalar,
    /// Statistics for this array; starts out as the default value.
    stats_set: ArrayStats,
}
274
/// Marker type implementing the vtable for the sparse encoding.
#[derive(Clone, Debug)]
pub struct Sparse;

impl Sparse {
    /// Identifier of the sparse encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
281
282impl SparseArray {
283 pub fn try_new(
284 indices: ArrayRef,
285 values: ArrayRef,
286 len: usize,
287 fill_value: Scalar,
288 ) -> VortexResult<Self> {
289 vortex_ensure!(
290 indices.len() == values.len(),
291 "Mismatched indices {} and values {} length",
292 indices.len(),
293 values.len()
294 );
295
296 if indices.is_host() {
297 debug_assert_eq!(
298 indices.statistics().compute_is_strict_sorted(),
299 Some(true),
300 "SparseArray: indices must be strict-sorted"
301 );
302
303 if !indices.is_empty() {
305 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
306
307 vortex_ensure!(
308 last_index < len,
309 "Array length was {len} but the last index is {last_index}"
310 );
311 }
312 }
313
314 Ok(Self {
315 patches: Patches::new(len, 0, indices, values, None)?,
317 fill_value,
318 stats_set: Default::default(),
319 })
320 }
321
322 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
324 vortex_ensure!(
325 fill_value.dtype() == patches.values().dtype(),
326 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
327 fill_value,
328 patches.values().dtype(),
329 fill_value.dtype(),
330 );
331
332 Ok(Self {
333 patches,
334 fill_value,
335 stats_set: Default::default(),
336 })
337 }
338
339 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
340 Self {
341 patches,
342 fill_value,
343 stats_set: Default::default(),
344 }
345 }
346
347 #[inline]
348 pub fn patches(&self) -> &Patches {
349 &self.patches
350 }
351
352 #[inline]
353 pub fn resolved_patches(&self) -> VortexResult<Patches> {
354 let patches = self.patches();
355 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
356 let indices = patches.indices().to_array().binary(
357 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
358 Operator::Sub,
359 )?;
360
361 Patches::new(
362 patches.array_len(),
363 0,
364 indices,
365 patches.values().clone(),
366 None,
368 )
369 }
370
371 #[inline]
372 pub fn fill_scalar(&self) -> &Scalar {
373 &self.fill_value
374 }
375
376 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
380 if let Some(fill_value) = fill_value.as_ref()
381 && array.dtype() != fill_value.dtype()
382 {
383 vortex_bail!(
384 "Array and fill value types must match. got {} and {}",
385 array.dtype(),
386 fill_value.dtype()
387 )
388 }
389 let mask = array.validity_mask()?;
390
391 if mask.all_false() {
392 return Ok(
394 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
395 );
396 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
397 let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
400 let non_null_indices = match mask.indices() {
401 AllOr::All => {
402 unreachable!("Mask is mostly null")
404 }
405 AllOr::None => {
406 unreachable!("Mask is mostly null but not all null")
408 }
409 AllOr::Some(values) => {
410 let buffer: Buffer<u32> = values
411 .iter()
412 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
413 .collect();
414
415 buffer.into_array()
416 }
417 };
418
419 return Ok(SparseArray::try_new(
420 non_null_indices,
421 non_null_values,
422 array.len(),
423 Scalar::null(array.dtype().clone()),
424 )?
425 .into_array());
426 }
427
428 let fill = if let Some(fill) = fill_value {
429 fill
430 } else {
431 let (top_pvalue, _) = array
433 .to_primitive()
434 .top_value()?
435 .vortex_expect("Non empty or all null array");
436
437 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
438 };
439
440 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
441 let non_top_mask = Mask::from_buffer(
442 array
443 .to_array()
444 .binary(fill_array.clone(), Operator::NotEq)?
445 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
446 .to_bool()
447 .to_bit_buffer(),
448 );
449
450 let non_top_values = array
451 .filter(non_top_mask.clone())?
452 .to_canonical()?
453 .into_array();
454
455 let indices: Buffer<u64> = match non_top_mask {
456 Mask::AllTrue(count) => {
457 (0u64..count as u64).collect()
459 }
460 Mask::AllFalse(_) => {
461 return Ok(fill_array);
463 }
464 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
465 };
466
467 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
468 .map(|a| a.into_array())
469 }
470}
471
472impl ValidityVTable<Sparse> for Sparse {
473 fn validity(array: &SparseArray) -> VortexResult<Validity> {
474 let patches = unsafe {
475 Patches::new_unchecked(
476 array.patches.array_len(),
477 array.patches.offset(),
478 array.patches.indices().clone(),
479 array
480 .patches
481 .values()
482 .validity()?
483 .to_array(array.patches.values().len()),
484 array.patches.chunk_offsets().clone(),
485 array.patches.offset_within_chunk(),
486 )
487 };
488
489 Ok(Validity::Array(
490 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
491 .into_array(),
492 ))
493 }
494}
495
// Unit tests for construction, element access, slicing, validity and encoding of
// `SparseArray`.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null fill scalar of nullable `i32` type.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A fill scalar built from the plain `i32` value 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8 and the
    /// given fill value everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // Cast the values so that their dtype matches the fill value's dtype.
        values = values.cast(fill_value.dtype().clone()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions yield the fill value, patched positions their patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
    }

    // Index 10 is one past the end of the length-10 array.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    // Same as `test_scalar_at`, but with constant arrays as indices and values.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            arr.scalar_at(10)
                .unwrap()
                .as_primitive()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    // After slicing 2..7 the patch at original index 2 is at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
    }

    // With a null fill, only the patched positions (original 2 and 5) are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // With null patches and a non-null fill, the patched positions are the invalid ones.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slicing an already sliced array keeps the patches at the right positions.
    #[test]
    pub fn scalar_at_sliced_twice() {
        // Original index 2 is at position 1 after slicing 1..8.
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );

        // Original index 5 is at position 3 after slicing 1..6 of the first slice.
        let sliced_twice = sliced_once.slice(1..6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
    }

    // With a null fill, exactly the patched indices 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_bit_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // With a non-null fill and non-null patches, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    // The last index (100) is not smaller than the length (100), so construction fails.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // The last index (100) is smaller than the length (101), so construction succeeds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // Encoding an array containing nulls preserves both its validity and its values.
    #[test]
    fn encode_with_nulls() {
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ]),
        );
        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    // Null patch values (at indices 4 and 6) are invalid, just like the null fill.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array =
            SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}