1use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DeserializeMetadata;
14use vortex_array::ExecutionCtx;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityVTable;
37use vortex_array::vtable::patches_child;
38use vortex_array::vtable::patches_child_name;
39use vortex_array::vtable::patches_nchildren;
40use vortex_buffer::Buffer;
41use vortex_buffer::ByteBufferMut;
42use vortex_error::VortexExpect as _;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_ensure_eq;
47use vortex_error::vortex_panic;
48use vortex_mask::AllOr;
49use vortex_mask::Mask;
50use vortex_session::VortexSession;
51
52use crate::canonical::execute_sparse;
53use crate::rules::RULES;
54
55mod canonical;
56mod compute;
57mod kernel;
58mod ops;
59mod rules;
60mod slice;
61
62vtable!(Sparse);
63
64#[derive(Debug)]
65pub struct SparseMetadata {
66 patches: PatchesMetadata,
67 fill_value: Scalar,
68}
69
70#[derive(Clone, prost::Message)]
71#[repr(C)]
72pub struct ProstPatchesMetadata {
73 #[prost(message, required, tag = "1")]
74 patches: PatchesMetadata,
75}
76
77impl VTable for SparseVTable {
78 type Array = SparseArray;
79
80 type Metadata = SparseMetadata;
81 type OperationsVTable = Self;
82 type ValidityVTable = Self;
83
84 fn id(_array: &Self::Array) -> ArrayId {
85 Self::ID
86 }
87
88 fn len(array: &SparseArray) -> usize {
89 array.patches.array_len()
90 }
91
92 fn dtype(array: &SparseArray) -> &DType {
93 array.fill_scalar().dtype()
94 }
95
96 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
97 array.stats_set.to_ref(array.as_ref())
98 }
99
100 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
101 array.patches.array_hash(state, precision);
102 array.fill_value.hash(state);
103 }
104
105 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
106 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
107 }
108
109 fn nbuffers(_array: &SparseArray) -> usize {
110 1
111 }
112
113 fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
114 match idx {
115 0 => {
116 let fill_value_buffer =
117 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
118 BufferHandle::new_host(fill_value_buffer)
119 }
120 _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
121 }
122 }
123
124 fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
125 match idx {
126 0 => Some("fill_value".to_string()),
127 _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
128 }
129 }
130
131 fn nchildren(array: &SparseArray) -> usize {
132 patches_nchildren(array.patches())
133 }
134
135 fn child(array: &SparseArray, idx: usize) -> ArrayRef {
136 patches_child(array.patches(), idx)
137 }
138
139 fn child_name(_array: &SparseArray, idx: usize) -> String {
140 patches_child_name(idx).to_string()
141 }
142
143 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
144 let patches = array.patches().to_metadata(array.len(), array.dtype())?;
145
146 Ok(SparseMetadata {
147 patches,
148 fill_value: array.fill_value.clone(),
149 })
150 }
151
152 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
153 let prost_patches = ProstPatchesMetadata {
154 patches: metadata.patches,
155 };
156
157 Ok(Some(prost_patches.encode_to_vec()))
159 }
160
161 fn deserialize(
162 bytes: &[u8],
163 dtype: &DType,
164 _len: usize,
165 buffers: &[BufferHandle],
166 _session: &VortexSession,
167 ) -> VortexResult<Self::Metadata> {
168 let prost_patches =
169 <ProstMetadata<ProstPatchesMetadata> as DeserializeMetadata>::deserialize(bytes)?;
170
171 if buffers.len() != 1 {
174 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
175 }
176 let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
177
178 let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype)?;
179 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
180
181 Ok(SparseMetadata {
182 patches: prost_patches.patches,
183 fill_value,
184 })
185 }
186
187 fn build(
188 dtype: &DType,
189 len: usize,
190 metadata: &Self::Metadata,
191 _buffers: &[BufferHandle],
192 children: &dyn ArrayChildren,
193 ) -> VortexResult<SparseArray> {
194 vortex_ensure_eq!(
195 children.len(),
196 2,
197 "SparseArray expects 2 children for sparse encoding, found {}",
198 children.len()
199 );
200 vortex_ensure_eq!(
201 metadata.patches.offset()?,
202 0,
203 "Patches must start at offset 0"
204 );
205
206 let patch_indices = children.get(
207 0,
208 &metadata.patches.indices_dtype()?,
209 metadata.patches.len()?,
210 )?;
211 let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
212
213 SparseArray::try_new(
214 patch_indices,
215 patch_values,
216 len,
217 metadata.fill_value.clone(),
218 )
219 }
220
221 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
222 vortex_ensure_eq!(
223 children.len(),
224 2,
225 "SparseArray expects 2 children, got {}",
226 children.len()
227 );
228
229 let mut children_iter = children.into_iter();
230 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
231 let patch_values = children_iter.next().vortex_expect("patch_values child");
232
233 array.patches = Patches::new(
234 array.patches.array_len(),
235 array.patches.offset(),
236 patch_indices,
237 patch_values,
238 array.patches.chunk_offsets().clone(),
239 )?;
240
241 Ok(())
242 }
243
244 fn reduce_parent(
245 array: &Self::Array,
246 parent: &ArrayRef,
247 child_idx: usize,
248 ) -> VortexResult<Option<ArrayRef>> {
249 RULES.evaluate(array, parent, child_idx)
250 }
251
252 fn execute_parent(
253 array: &Self::Array,
254 parent: &ArrayRef,
255 child_idx: usize,
256 ctx: &mut ExecutionCtx,
257 ) -> VortexResult<Option<ArrayRef>> {
258 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
259 }
260
261 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
262 execute_sparse(array, ctx)
263 }
264}
265
266#[derive(Clone, Debug)]
267pub struct SparseArray {
268 patches: Patches,
269 fill_value: Scalar,
270 stats_set: ArrayStats,
271}
272
273#[derive(Debug)]
274pub struct SparseVTable;
275
276impl SparseVTable {
277 pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
278}
279
280impl SparseArray {
281 pub fn try_new(
282 indices: ArrayRef,
283 values: ArrayRef,
284 len: usize,
285 fill_value: Scalar,
286 ) -> VortexResult<Self> {
287 vortex_ensure!(
288 indices.len() == values.len(),
289 "Mismatched indices {} and values {} length",
290 indices.len(),
291 values.len()
292 );
293
294 if indices.is_host() {
295 debug_assert_eq!(
296 indices.statistics().compute_is_strict_sorted(),
297 Some(true),
298 "SparseArray: indices must be strict-sorted"
299 );
300
301 if !indices.is_empty() {
303 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
304
305 vortex_ensure!(
306 last_index < len,
307 "Array length was {len} but the last index is {last_index}"
308 );
309 }
310 }
311
312 Ok(Self {
313 patches: Patches::new(len, 0, indices, values, None)?,
315 fill_value,
316 stats_set: Default::default(),
317 })
318 }
319
320 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
322 vortex_ensure!(
323 fill_value.dtype() == patches.values().dtype(),
324 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
325 fill_value,
326 patches.values().dtype(),
327 fill_value.dtype(),
328 );
329
330 Ok(Self {
331 patches,
332 fill_value,
333 stats_set: Default::default(),
334 })
335 }
336
337 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
338 Self {
339 patches,
340 fill_value,
341 stats_set: Default::default(),
342 }
343 }
344
345 #[inline]
346 pub fn patches(&self) -> &Patches {
347 &self.patches
348 }
349
350 #[inline]
351 pub fn resolved_patches(&self) -> VortexResult<Patches> {
352 let patches = self.patches();
353 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
354 let indices = patches.indices().to_array().binary(
355 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
356 Operator::Sub,
357 )?;
358
359 Patches::new(
360 patches.array_len(),
361 0,
362 indices,
363 patches.values().clone(),
364 None,
366 )
367 }
368
369 #[inline]
370 pub fn fill_scalar(&self) -> &Scalar {
371 &self.fill_value
372 }
373
374 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
378 if let Some(fill_value) = fill_value.as_ref()
379 && array.dtype() != fill_value.dtype()
380 {
381 vortex_bail!(
382 "Array and fill value types must match. got {} and {}",
383 array.dtype(),
384 fill_value.dtype()
385 )
386 }
387 let mask = array.validity_mask()?;
388
389 if mask.all_false() {
390 return Ok(
392 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
393 );
394 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
395 let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
398 let non_null_indices = match mask.indices() {
399 AllOr::All => {
400 unreachable!("Mask is mostly null")
402 }
403 AllOr::None => {
404 unreachable!("Mask is mostly null but not all null")
406 }
407 AllOr::Some(values) => {
408 let buffer: Buffer<u32> = values
409 .iter()
410 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
411 .collect();
412
413 buffer.into_array()
414 }
415 };
416
417 return Ok(SparseArray::try_new(
418 non_null_indices,
419 non_null_values,
420 array.len(),
421 Scalar::null(array.dtype().clone()),
422 )?
423 .into_array());
424 }
425
426 let fill = if let Some(fill) = fill_value {
427 fill
428 } else {
429 let (top_pvalue, _) = array
431 .to_primitive()
432 .top_value()?
433 .vortex_expect("Non empty or all null array");
434
435 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
436 };
437
438 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
439 let non_top_mask = Mask::from_buffer(
440 array
441 .to_array()
442 .binary(fill_array.clone(), Operator::NotEq)?
443 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
444 .to_bool()
445 .to_bit_buffer(),
446 );
447
448 let non_top_values = array
449 .filter(non_top_mask.clone())?
450 .to_canonical()?
451 .into_array();
452
453 let indices: Buffer<u64> = match non_top_mask {
454 Mask::AllTrue(count) => {
455 (0u64..count as u64).collect()
457 }
458 Mask::AllFalse(_) => {
459 return Ok(fill_array);
461 }
462 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
463 };
464
465 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
466 .map(|a| a.into_array())
467 }
468}
469
470impl ValidityVTable<SparseVTable> for SparseVTable {
471 fn validity(array: &SparseArray) -> VortexResult<Validity> {
472 let patches = unsafe {
473 Patches::new_unchecked(
474 array.patches.array_len(),
475 array.patches.offset(),
476 array.patches.indices().clone(),
477 array
478 .patches
479 .values()
480 .validity()?
481 .to_array(array.patches.values().len()),
482 array.patches.chunk_offsets().clone(),
483 array.patches.offset_within_chunk(),
484 )
485 };
486
487 Ok(Validity::Array(
488 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
489 .into_array(),
490 ))
491 }
492}
493
494#[cfg(test)]
495mod test {
496 use itertools::Itertools;
497 use vortex_array::IntoArray;
498 use vortex_array::arrays::ConstantArray;
499 use vortex_array::arrays::PrimitiveArray;
500 use vortex_array::assert_arrays_eq;
501 use vortex_array::builtins::ArrayBuiltins;
502 use vortex_array::dtype::DType;
503 use vortex_array::dtype::Nullability;
504 use vortex_array::dtype::PType;
505 use vortex_array::scalar::Scalar;
506 use vortex_array::validity::Validity;
507 use vortex_buffer::buffer;
508 use vortex_error::VortexExpect;
509
510 use super::*;
511
512 fn nullable_fill() -> Scalar {
513 Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
514 }
515
516 fn non_nullable_fill() -> Scalar {
517 Scalar::from(42i32)
518 }
519
520 fn sparse_array(fill_value: Scalar) -> ArrayRef {
521 let mut values = buffer![100i32, 200, 300].into_array();
523 values = values.cast(fill_value.dtype().clone()).unwrap();
524
525 SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
526 .unwrap()
527 .into_array()
528 }
529
530 #[test]
531 pub fn test_scalar_at() {
532 let array = sparse_array(nullable_fill());
533
534 assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
535 assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
536 assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
537 }
538
539 #[test]
540 #[should_panic(expected = "out of bounds")]
541 fn test_scalar_at_oob() {
542 let array = sparse_array(nullable_fill());
543 array.scalar_at(10).unwrap();
544 }
545
546 #[test]
547 pub fn test_scalar_at_again() {
548 let arr = SparseArray::try_new(
549 ConstantArray::new(10u32, 1).into_array(),
550 ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
551 100,
552 Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
553 )
554 .unwrap();
555
556 assert_eq!(
557 arr.scalar_at(10)
558 .unwrap()
559 .as_primitive()
560 .typed_value::<u32>(),
561 Some(1234)
562 );
563 assert!(arr.scalar_at(0).unwrap().is_null());
564 assert!(arr.scalar_at(99).unwrap().is_null());
565 }
566
567 #[test]
568 pub fn scalar_at_sliced() {
569 let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
570 assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
571 }
572
573 #[test]
574 pub fn validity_mask_sliced_null_fill() {
575 let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
576 assert_eq!(
577 sliced.validity_mask().unwrap(),
578 Mask::from_iter(vec![true, false, false, true, false])
579 );
580 }
581
582 #[test]
583 pub fn validity_mask_sliced_nonnull_fill() {
584 let sliced = SparseArray::try_new(
585 buffer![2u64, 5, 8].into_array(),
586 ConstantArray::new(
587 Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
588 3,
589 )
590 .into_array(),
591 10,
592 Scalar::primitive(1.0f32, Nullability::Nullable),
593 )
594 .unwrap()
595 .slice(2..7)
596 .unwrap();
597
598 assert_eq!(
599 sliced.validity_mask().unwrap(),
600 Mask::from_iter(vec![false, true, true, false, true])
601 );
602 }
603
604 #[test]
605 pub fn scalar_at_sliced_twice() {
606 let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
607 assert_eq!(
608 usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
609 100
610 );
611
612 let sliced_twice = sliced_once.slice(1..6).unwrap();
613 assert_eq!(
614 usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
615 200
616 );
617 }
618
619 #[test]
620 pub fn sparse_validity_mask() {
621 let array = sparse_array(nullable_fill());
622 assert_eq!(
623 array
624 .validity_mask()
625 .unwrap()
626 .to_bit_buffer()
627 .iter()
628 .collect_vec(),
629 [
630 false, false, true, false, false, true, false, false, true, false
631 ]
632 );
633 }
634
635 #[test]
636 fn sparse_validity_mask_non_null_fill() {
637 let array = sparse_array(non_nullable_fill());
638 assert!(array.validity_mask().unwrap().all_true());
639 }
640
641 #[test]
642 #[should_panic]
643 fn test_invalid_length() {
644 let values = buffer![15_u32, 135, 13531, 42].into_array();
645 let indices = buffer![10_u64, 11, 50, 100].into_array();
646
647 SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
648 }
649
650 #[test]
651 fn test_valid_length() {
652 let values = buffer![15_u32, 135, 13531, 42].into_array();
653 let indices = buffer![10_u64, 11, 50, 100].into_array();
654
655 SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
656 }
657
658 #[test]
659 fn encode_with_nulls() {
660 let original = PrimitiveArray::new(
661 buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
662 Validity::from_iter(vec![
663 true, true, false, true, false, true, false, true, true, false, true, false,
664 ]),
665 );
666 let sparse = SparseArray::encode(&original.clone().into_array(), None)
667 .vortex_expect("SparseArray::encode should succeed for test data");
668 assert_eq!(
669 sparse.validity_mask().unwrap(),
670 Mask::from_iter(vec![
671 true, true, false, true, false, true, false, true, true, false, true, false,
672 ])
673 );
674 assert_arrays_eq!(sparse.to_primitive(), original);
675 }
676
677 #[test]
678 fn validity_mask_includes_null_values_when_fill_is_null() {
679 let indices = buffer![0u8, 2, 4, 6, 8].into_array();
680 let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
681 .into_array();
682 let array =
683 SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
684 let actual = array.validity_mask().unwrap();
685 let expected = Mask::from_iter([
686 true, false, true, false, false, false, false, false, true, false,
687 ]);
688
689 assert_eq!(actual, expected);
690 }
691}