1use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::Array;
10use vortex_array::ArrayBufferVisitor;
11use vortex_array::ArrayChildVisitor;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayRef;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::ConstantArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::builtins::ArrayBuiltins;
23use vortex_array::compute::Operator;
24use vortex_array::compute::compare;
25use vortex_array::compute::filter;
26use vortex_array::compute::sub_scalar;
27use vortex_array::patches::Patches;
28use vortex_array::patches::PatchesMetadata;
29use vortex_array::scalar::Scalar;
30use vortex_array::scalar::ScalarValue;
31use vortex_array::serde::ArrayChildren;
32use vortex_array::stats::ArrayStats;
33use vortex_array::stats::StatsSetRef;
34use vortex_array::validity::Validity;
35use vortex_array::vtable;
36use vortex_array::vtable::ArrayId;
37use vortex_array::vtable::BaseArrayVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_array::vtable::VisitorVTable;
41use vortex_buffer::Buffer;
42use vortex_buffer::ByteBufferMut;
43use vortex_dtype::DType;
44use vortex_dtype::Nullability;
45use vortex_error::VortexExpect as _;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_session::VortexSession;
52
53use crate::canonical::execute_sparse;
54use crate::rules::RULES;
55
56mod canonical;
57mod compute;
58mod kernel;
59mod ops;
60mod rules;
61mod slice;
62
63vtable!(Sparse);
64
65#[derive(Clone, prost::Message)]
66#[repr(C)]
67pub struct SparseMetadata {
68 #[prost(message, required, tag = "1")]
69 patches: PatchesMetadata,
70}
71
72impl VTable for SparseVTable {
73 type Array = SparseArray;
74
75 type Metadata = ProstMetadata<SparseMetadata>;
76
77 type ArrayVTable = Self;
78 type OperationsVTable = Self;
79 type ValidityVTable = Self;
80 type VisitorVTable = Self;
81
82 fn id(_array: &Self::Array) -> ArrayId {
83 Self::ID
84 }
85
86 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
87 Ok(ProstMetadata(SparseMetadata {
88 patches: array.patches().to_metadata(array.len(), array.dtype())?,
89 }))
90 }
91
92 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
93 Ok(Some(metadata.0.encode_to_vec()))
94 }
95
96 fn deserialize(
97 bytes: &[u8],
98 _dtype: &DType,
99 _len: usize,
100 _buffers: &[BufferHandle],
101 _session: &VortexSession,
102 ) -> VortexResult<Self::Metadata> {
103 Ok(ProstMetadata(SparseMetadata::decode(bytes)?))
104 }
105
106 fn build(
107 dtype: &DType,
108 len: usize,
109 metadata: &Self::Metadata,
110 buffers: &[BufferHandle],
111 children: &dyn ArrayChildren,
112 ) -> VortexResult<SparseArray> {
113 if children.len() != 2 {
114 vortex_bail!(
115 "Expected 2 children for sparse encoding, found {}",
116 children.len()
117 )
118 }
119 vortex_ensure!(
120 metadata.0.patches.offset()? == 0,
121 "Patches must start at offset 0"
122 );
123
124 let patch_indices = children.get(
125 0,
126 &metadata.0.patches.indices_dtype()?,
127 metadata.0.patches.len()?,
128 )?;
129 let patch_values = children.get(1, dtype, metadata.0.patches.len()?)?;
130
131 if buffers.len() != 1 {
132 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
133 }
134
135 let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
136 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
137
138 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
139
140 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
141 }
142
143 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
144 vortex_ensure!(
145 children.len() == 2,
146 "SparseArray expects 2 children, got {}",
147 children.len()
148 );
149
150 let mut children_iter = children.into_iter();
151 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
152 let patch_values = children_iter.next().vortex_expect("patch_values child");
153
154 array.patches = Patches::new(
155 array.patches.array_len(),
156 array.patches.offset(),
157 patch_indices,
158 patch_values,
159 array.patches.chunk_offsets().clone(),
160 )?;
161
162 Ok(())
163 }
164
165 fn reduce_parent(
166 array: &Self::Array,
167 parent: &ArrayRef,
168 child_idx: usize,
169 ) -> VortexResult<Option<ArrayRef>> {
170 RULES.evaluate(array, parent, child_idx)
171 }
172
173 fn execute_parent(
174 array: &Self::Array,
175 parent: &ArrayRef,
176 child_idx: usize,
177 ctx: &mut ExecutionCtx,
178 ) -> VortexResult<Option<ArrayRef>> {
179 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
180 }
181
182 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
183 execute_sparse(array)
184 }
185}
186
187#[derive(Clone, Debug)]
188pub struct SparseArray {
189 patches: Patches,
190 fill_value: Scalar,
191 stats_set: ArrayStats,
192}
193
194#[derive(Debug)]
195pub struct SparseVTable;
196
197impl SparseVTable {
198 pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
199}
200
201impl SparseArray {
202 pub fn try_new(
203 indices: ArrayRef,
204 values: ArrayRef,
205 len: usize,
206 fill_value: Scalar,
207 ) -> VortexResult<Self> {
208 vortex_ensure!(
209 indices.len() == values.len(),
210 "Mismatched indices {} and values {} length",
211 indices.len(),
212 values.len()
213 );
214
215 if indices.is_host() {
216 debug_assert_eq!(
217 indices.statistics().compute_is_strict_sorted(),
218 Some(true),
219 "SparseArray: indices must be strict-sorted"
220 );
221
222 if !indices.is_empty() {
224 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
225
226 vortex_ensure!(
227 last_index < len,
228 "Array length was {len} but the last index is {last_index}"
229 );
230 }
231 }
232
233 Ok(Self {
234 patches: Patches::new(len, 0, indices, values, None)?,
236 fill_value,
237 stats_set: Default::default(),
238 })
239 }
240
241 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
243 vortex_ensure!(
244 fill_value.dtype() == patches.values().dtype(),
245 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
246 fill_value,
247 patches.values().dtype(),
248 fill_value.dtype(),
249 );
250
251 Ok(Self {
252 patches,
253 fill_value,
254 stats_set: Default::default(),
255 })
256 }
257
258 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
259 Self {
260 patches,
261 fill_value,
262 stats_set: Default::default(),
263 }
264 }
265
266 #[inline]
267 pub fn patches(&self) -> &Patches {
268 &self.patches
269 }
270
271 #[inline]
272 pub fn resolved_patches(&self) -> VortexResult<Patches> {
273 let patches = self.patches();
274 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
275 let indices = sub_scalar(patches.indices(), indices_offset)?;
276
277 Patches::new(
278 patches.array_len(),
279 0,
280 indices,
281 patches.values().clone(),
282 None,
284 )
285 }
286
287 #[inline]
288 pub fn fill_scalar(&self) -> &Scalar {
289 &self.fill_value
290 }
291
292 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
296 if let Some(fill_value) = fill_value.as_ref()
297 && array.dtype() != fill_value.dtype()
298 {
299 vortex_bail!(
300 "Array and fill value types must match. got {} and {}",
301 array.dtype(),
302 fill_value.dtype()
303 )
304 }
305 let mask = array.validity_mask()?;
306
307 if mask.all_false() {
308 return Ok(
310 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
311 );
312 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
313 let non_null_values = filter(array, &mask)?;
315 let non_null_indices = match mask.indices() {
316 AllOr::All => {
317 unreachable!("Mask is mostly null")
319 }
320 AllOr::None => {
321 unreachable!("Mask is mostly null but not all null")
323 }
324 AllOr::Some(values) => {
325 let buffer: Buffer<u32> = values
326 .iter()
327 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
328 .collect();
329
330 buffer.into_array()
331 }
332 };
333
334 return Ok(SparseArray::try_new(
335 non_null_indices,
336 non_null_values,
337 array.len(),
338 Scalar::null(array.dtype().clone()),
339 )?
340 .into_array());
341 }
342
343 let fill = if let Some(fill) = fill_value {
344 fill
345 } else {
346 let (top_pvalue, _) = array
348 .to_primitive()
349 .top_value()?
350 .vortex_expect("Non empty or all null array");
351
352 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
353 };
354
355 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
356 let non_top_mask = Mask::from_buffer(
357 compare(array, &fill_array, Operator::NotEq)?
358 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
359 .to_bool()
360 .to_bit_buffer(),
361 );
362
363 let non_top_values = filter(array, &non_top_mask)?;
364
365 let indices: Buffer<u64> = match non_top_mask {
366 Mask::AllTrue(count) => {
367 (0u64..count as u64).collect()
369 }
370 Mask::AllFalse(_) => {
371 return Ok(fill_array);
373 }
374 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
375 };
376
377 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
378 .map(|a| a.into_array())
379 }
380}
381
382impl BaseArrayVTable<SparseVTable> for SparseVTable {
383 fn len(array: &SparseArray) -> usize {
384 array.patches.array_len()
385 }
386
387 fn dtype(array: &SparseArray) -> &DType {
388 array.fill_scalar().dtype()
389 }
390
391 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
392 array.stats_set.to_ref(array.as_ref())
393 }
394
395 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
396 array.patches.array_hash(state, precision);
397 array.fill_value.hash(state);
398 }
399
400 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
401 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
402 }
403}
404
405impl ValidityVTable<SparseVTable> for SparseVTable {
406 fn validity(array: &SparseArray) -> VortexResult<Validity> {
407 let patches = unsafe {
408 Patches::new_unchecked(
409 array.patches.array_len(),
410 array.patches.offset(),
411 array.patches.indices().clone(),
412 array
413 .patches
414 .values()
415 .validity()?
416 .to_array(array.patches.values().len()),
417 array.patches.chunk_offsets().clone(),
418 array.patches.offset_within_chunk(),
419 )
420 };
421
422 Ok(Validity::Array(
423 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
424 .into_array(),
425 ))
426 }
427}
428
429impl VisitorVTable<SparseVTable> for SparseVTable {
430 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
431 let fill_value_buffer =
432 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
433 visitor.visit_buffer_handle("fill_value", &BufferHandle::new_host(fill_value_buffer));
434 }
435
436 fn nbuffers(_array: &SparseArray) -> usize {
437 1
438 }
439
440 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
441 visitor.visit_patches(array.patches())
442 }
443
444 fn nchildren(array: &SparseArray) -> usize {
445 2 + array.patches().chunk_offsets().is_some() as usize
447 }
448}
449
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// The scalar `42i32`.
    fn non_nullable_fill() -> Scalar {
        let value: i32 = 42;
        Scalar::from(value)
    }

    /// Length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let indices = buffer![2u64, 5, 8].into_array();
        // Cast the values to the dtype of the fill scalar.
        let values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();

        let sparse = SparseArray::try_new(indices, values, 10, fill_value).unwrap();
        sparse.into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        for (index, expected) in [(2usize, 100_i32), (5, 200)] {
            assert_eq!(
                array.scalar_at(index).unwrap(),
                Scalar::from(Some(expected))
            );
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));
        for index in [0usize, 99] {
            assert!(arr.scalar_at(index).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap();
        let sliced = sparse.slice(2..7).unwrap();

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let once_value = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let twice_value = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let actual = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();
        // Only the patched positions (2, 5 and 8) are valid.
        let expected = [
            false, false, true, false, false, true, false, false, true, false,
        ];
        assert_eq!(actual, expected);
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 is out of range for a length-100 array.
        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        );
        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let array = SparseArray::try_new(
            patch_indices,
            patch_values,
            10,
            Scalar::null_native::<i16>(),
        )
        .unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask().unwrap();

        assert_eq!(actual, expected);
    }
}