1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use kernel::PARENT_KERNELS;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::Array;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityVTable;
38use vortex_array::vtable::patches_child;
39use vortex_array::vtable::patches_child_name;
40use vortex_array::vtable::patches_nchildren;
41use vortex_buffer::Buffer;
42use vortex_buffer::ByteBufferMut;
43use vortex_error::VortexExpect as _;
44use vortex_error::VortexResult;
45use vortex_error::vortex_bail;
46use vortex_error::vortex_ensure;
47use vortex_error::vortex_ensure_eq;
48use vortex_error::vortex_panic;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_session::VortexSession;
52
53use crate::canonical::execute_sparse;
54use crate::rules::RULES;
55
56mod canonical;
57mod compute;
58mod kernel;
59mod ops;
60mod rules;
61mod slice;
62
// Invokes `vortex_array::vtable!` for the `Sparse` encoding.
// NOTE(review): the items this generates (e.g. the `Array<Sparse>` glue used by the
// `VTable` impl below) are defined by that macro — confirm against its definition.
vtable!(Sparse);
64
/// Decoded metadata for a [`SparseArray`].
///
/// Only `patches` is written into the serialized metadata bytes (via
/// [`ProstPatchesMetadata`]); `fill_value` is stored in the array's single buffer and is
/// reattached when the metadata is deserialized.
#[derive(Debug)]
pub struct SparseMetadata {
    /// Metadata describing the patch indices/values children.
    patches: PatchesMetadata,
    /// Value of every position not covered by a patch; its dtype is the array's dtype.
    fill_value: Scalar,
}
70
/// Protobuf message holding the serialized metadata bytes of a sparse array.
///
/// NOTE(review): `#[repr(C)]` does not influence prost's wire encoding and looks
/// unnecessary here; confirm nothing relies on this struct's memory layout before removing.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct ProstPatchesMetadata {
    /// Patch metadata, encoded as required protobuf field 1.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
77
78impl VTable for Sparse {
79 type Array = SparseArray;
80
81 type Metadata = SparseMetadata;
82 type OperationsVTable = Self;
83 type ValidityVTable = Self;
84
85 fn vtable(_array: &Self::Array) -> &Self {
86 &Sparse
87 }
88
89 fn id(&self) -> ArrayId {
90 Self::ID
91 }
92
93 fn len(array: &SparseArray) -> usize {
94 array.patches.array_len()
95 }
96
97 fn dtype(array: &SparseArray) -> &DType {
98 array.fill_scalar().dtype()
99 }
100
101 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
102 array.stats_set.to_ref(array.as_ref())
103 }
104
105 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
106 array.patches.array_hash(state, precision);
107 array.fill_value.hash(state);
108 }
109
110 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
111 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
112 }
113
114 fn nbuffers(_array: &SparseArray) -> usize {
115 1
116 }
117
118 fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
119 match idx {
120 0 => {
121 let fill_value_buffer =
122 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
123 BufferHandle::new_host(fill_value_buffer)
124 }
125 _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
126 }
127 }
128
129 fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
130 match idx {
131 0 => Some("fill_value".to_string()),
132 _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
133 }
134 }
135
136 fn nchildren(array: &SparseArray) -> usize {
137 patches_nchildren(array.patches())
138 }
139
140 fn child(array: &SparseArray, idx: usize) -> ArrayRef {
141 patches_child(array.patches(), idx)
142 }
143
144 fn child_name(_array: &SparseArray, idx: usize) -> String {
145 patches_child_name(idx).to_string()
146 }
147
148 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
149 let patches = array.patches().to_metadata(array.len(), array.dtype())?;
150
151 Ok(SparseMetadata {
152 patches,
153 fill_value: array.fill_value.clone(),
154 })
155 }
156
157 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
158 let prost_patches = ProstPatchesMetadata {
159 patches: metadata.patches,
160 };
161
162 Ok(Some(prost_patches.encode_to_vec()))
164 }
165
166 fn deserialize(
167 bytes: &[u8],
168 dtype: &DType,
169 _len: usize,
170 buffers: &[BufferHandle],
171 session: &VortexSession,
172 ) -> VortexResult<Self::Metadata> {
173 let prost_patches = ProstPatchesMetadata::decode(bytes)?;
174
175 if buffers.len() != 1 {
178 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
179 }
180 let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
181
182 let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
183 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
184
185 Ok(SparseMetadata {
186 patches: prost_patches.patches,
187 fill_value,
188 })
189 }
190
191 fn build(
192 dtype: &DType,
193 len: usize,
194 metadata: &Self::Metadata,
195 _buffers: &[BufferHandle],
196 children: &dyn ArrayChildren,
197 ) -> VortexResult<SparseArray> {
198 vortex_ensure_eq!(
199 children.len(),
200 2,
201 "SparseArray expects 2 children for sparse encoding, found {}",
202 children.len()
203 );
204
205 let patch_indices = children.get(
206 0,
207 &metadata.patches.indices_dtype()?,
208 metadata.patches.len()?,
209 )?;
210 let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
211
212 SparseArray::try_new_from_patches(
213 Patches::new(
214 len,
215 metadata.patches.offset()?,
216 patch_indices,
217 patch_values,
218 None,
219 )?,
220 metadata.fill_value.clone(),
221 )
222 }
223
224 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
225 vortex_ensure_eq!(
226 children.len(),
227 2,
228 "SparseArray expects 2 children, got {}",
229 children.len()
230 );
231
232 let mut children_iter = children.into_iter();
233 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
234 let patch_values = children_iter.next().vortex_expect("patch_values child");
235
236 array.patches = Patches::new(
237 array.patches.array_len(),
238 array.patches.offset(),
239 patch_indices,
240 patch_values,
241 array.patches.chunk_offsets().clone(),
242 )?;
243
244 Ok(())
245 }
246
247 fn reduce_parent(
248 array: &Array<Self>,
249 parent: &ArrayRef,
250 child_idx: usize,
251 ) -> VortexResult<Option<ArrayRef>> {
252 RULES.evaluate(array, parent, child_idx)
253 }
254
255 fn execute_parent(
256 array: &Array<Self>,
257 parent: &ArrayRef,
258 child_idx: usize,
259 ctx: &mut ExecutionCtx,
260 ) -> VortexResult<Option<ArrayRef>> {
261 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
262 }
263
264 fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
265 execute_sparse(&array, ctx).map(ExecutionResult::done)
266 }
267}
268
/// An array that stores explicit values only at selected positions.
///
/// Positions listed by `patches` take the corresponding patch value; every other position
/// takes `fill_value`, whose dtype is also reported as the array's dtype.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicit indices and values, plus the logical array length (`array_len`).
    patches: Patches,
    /// Value of every position not covered by `patches`.
    fill_value: Scalar,
    /// Statistics set exposed through `VTable::stats`.
    stats_set: ArrayStats,
}

/// Encoding marker type for sparse arrays; implements the sparse vtables.
#[derive(Clone, Debug)]
pub struct Sparse;

impl Sparse {
    /// Identifier of the sparse encoding, returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
282
283impl SparseArray {
284 pub fn try_new(
285 indices: ArrayRef,
286 values: ArrayRef,
287 len: usize,
288 fill_value: Scalar,
289 ) -> VortexResult<Self> {
290 vortex_ensure!(
291 indices.len() == values.len(),
292 "Mismatched indices {} and values {} length",
293 indices.len(),
294 values.len()
295 );
296
297 if indices.is_host() {
298 debug_assert_eq!(
299 indices.statistics().compute_is_strict_sorted(),
300 Some(true),
301 "SparseArray: indices must be strict-sorted"
302 );
303
304 if !indices.is_empty() {
306 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
307
308 vortex_ensure!(
309 last_index < len,
310 "Array length was {len} but the last index is {last_index}"
311 );
312 }
313 }
314
315 Ok(Self {
316 patches: Patches::new(len, 0, indices, values, None)?,
318 fill_value,
319 stats_set: Default::default(),
320 })
321 }
322
323 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
325 vortex_ensure!(
326 fill_value.dtype() == patches.values().dtype(),
327 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
328 fill_value,
329 patches.values().dtype(),
330 fill_value.dtype(),
331 );
332
333 Ok(Self {
334 patches,
335 fill_value,
336 stats_set: Default::default(),
337 })
338 }
339
340 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
341 Self {
342 patches,
343 fill_value,
344 stats_set: Default::default(),
345 }
346 }
347
348 #[inline]
349 pub fn patches(&self) -> &Patches {
350 &self.patches
351 }
352
353 #[inline]
354 pub fn resolved_patches(&self) -> VortexResult<Patches> {
355 let patches = self.patches();
356 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
357 let indices = patches.indices().to_array().binary(
358 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
359 Operator::Sub,
360 )?;
361
362 Patches::new(
363 patches.array_len(),
364 0,
365 indices,
366 patches.values().clone(),
367 None,
369 )
370 }
371
372 #[inline]
373 pub fn fill_scalar(&self) -> &Scalar {
374 &self.fill_value
375 }
376
377 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
381 if let Some(fill_value) = fill_value.as_ref()
382 && array.dtype() != fill_value.dtype()
383 {
384 vortex_bail!(
385 "Array and fill value types must match. got {} and {}",
386 array.dtype(),
387 fill_value.dtype()
388 )
389 }
390 let mask = array.validity_mask()?;
391
392 if mask.all_false() {
393 return Ok(
395 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
396 );
397 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
398 let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
401 let non_null_indices = match mask.indices() {
402 AllOr::All => {
403 unreachable!("Mask is mostly null")
405 }
406 AllOr::None => {
407 unreachable!("Mask is mostly null but not all null")
409 }
410 AllOr::Some(values) => {
411 let buffer: Buffer<u32> = values
412 .iter()
413 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
414 .collect();
415
416 buffer.into_array()
417 }
418 };
419
420 return Ok(SparseArray::try_new(
421 non_null_indices,
422 non_null_values,
423 array.len(),
424 Scalar::null(array.dtype().clone()),
425 )?
426 .into_array());
427 }
428
429 let fill = if let Some(fill) = fill_value {
430 fill
431 } else {
432 let (top_pvalue, _) = array
434 .to_primitive()
435 .top_value()?
436 .vortex_expect("Non empty or all null array");
437
438 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
439 };
440
441 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
442 let non_top_mask = Mask::from_buffer(
443 array
444 .to_array()
445 .binary(fill_array.clone(), Operator::NotEq)?
446 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
447 .to_bool()
448 .to_bit_buffer(),
449 );
450
451 let non_top_values = array
452 .filter(non_top_mask.clone())?
453 .to_canonical()?
454 .into_array();
455
456 let indices: Buffer<u64> = match non_top_mask {
457 Mask::AllTrue(count) => {
458 (0u64..count as u64).collect()
460 }
461 Mask::AllFalse(_) => {
462 return Ok(fill_array);
464 }
465 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
466 };
467
468 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
469 .map(|a| a.into_array())
470 }
471}
472
473impl ValidityVTable<Sparse> for Sparse {
474 fn validity(array: &SparseArray) -> VortexResult<Validity> {
475 let patches = unsafe {
476 Patches::new_unchecked(
477 array.patches.array_len(),
478 array.patches.offset(),
479 array.patches.indices().clone(),
480 array
481 .patches
482 .values()
483 .validity()?
484 .to_array(array.patches.values().len()),
485 array.patches.chunk_offsets().clone(),
486 array.patches.offset_within_chunk(),
487 )
488 };
489
490 Ok(Validity::Array(
491 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
492 .into_array(),
493 ))
494 }
495}
496
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null fill value of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill value: `42i32`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Builds a length-10 sparse array with 100, 200, 300 at positions 2, 5, 8.
    /// The values are cast to the fill value's dtype so that the dtypes match.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        values = values.cast(fill_value.dtype().clone()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions return the fill; patched positions return the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
    }

    /// Index == len is out of bounds.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    /// Non-primitive (constant) children: one patch at index 10 in a length-100 array.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            arr.scalar_at(10)
                .unwrap()
                .as_primitive()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    /// After slicing 2..7, position 0 of the slice maps to original position 2 (value 100).
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
    }

    /// With a null fill, only the patched positions (original 2 and 5) are valid in 2..7.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    /// With a non-null fill and null patch values, validity is inverted: the patched
    /// positions (original 2 and 5) are the invalid ones.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    /// Slicing composes: 1..8 then 1..6 maps slice position 3 to original position 5.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
    }

    /// With a null fill, only positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_bit_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// A non-null fill with non-null values yields an all-valid mask.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    /// Last index 100 is not below len 100, so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    /// The same indices are accepted once len is 101.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    /// Round-trips a partially-null array through `encode` (5 of 12 null, so the
    /// top-value path is taken), checking both validity and values.
    #[test]
    fn encode_with_nulls() {
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ]),
        );
        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    /// With a null fill, null patch values (at indices 4 and 6) must also be invalid.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array =
            SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}