1use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::Array;
10use vortex_array::ArrayBufferVisitor;
11use vortex_array::ArrayChildVisitor;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayRef;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::ConstantArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::builtins::ArrayBuiltins;
23use vortex_array::compute::filter;
24use vortex_array::dtype::DType;
25use vortex_array::dtype::Nullability;
26use vortex_array::patches::Patches;
27use vortex_array::patches::PatchesMetadata;
28use vortex_array::scalar::Scalar;
29use vortex_array::scalar::ScalarValue;
30use vortex_array::scalar_fn::fns::operators::Operator;
31use vortex_array::serde::ArrayChildren;
32use vortex_array::stats::ArrayStats;
33use vortex_array::stats::StatsSetRef;
34use vortex_array::validity::Validity;
35use vortex_array::vtable;
36use vortex_array::vtable::ArrayId;
37use vortex_array::vtable::BaseArrayVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_array::vtable::VisitorVTable;
41use vortex_buffer::Buffer;
42use vortex_buffer::ByteBufferMut;
43use vortex_error::VortexExpect as _;
44use vortex_error::VortexResult;
45use vortex_error::vortex_bail;
46use vortex_error::vortex_ensure;
47use vortex_mask::AllOr;
48use vortex_mask::Mask;
49use vortex_session::VortexSession;
50
51use crate::canonical::execute_sparse;
52use crate::rules::RULES;
53
54mod canonical;
55mod compute;
56mod kernel;
57mod ops;
58mod rules;
59mod slice;
60
// Declares the `Sparse` encoding through the `vtable!` macro from `vortex_array`.
// NOTE(review): presumably this generates the glue that ties `SparseVTable` to
// `SparseArray` (e.g. conversions to/from `ArrayRef`) — confirm against the macro definition.
vtable!(Sparse);
62
/// Protobuf metadata stored with a serialized [`SparseArray`].
///
/// It is produced by `VTable::metadata`, encoded by `VTable::serialize` and decoded again by
/// `VTable::deserialize`.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata of the patches; `build` reads the patch count, the indices dtype and the
    /// offset from it.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
69
70impl VTable for SparseVTable {
71 type Array = SparseArray;
72
73 type Metadata = ProstMetadata<SparseMetadata>;
74
75 type ArrayVTable = Self;
76 type OperationsVTable = Self;
77 type ValidityVTable = Self;
78 type VisitorVTable = Self;
79
80 fn id(_array: &Self::Array) -> ArrayId {
81 Self::ID
82 }
83
84 fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
85 Ok(ProstMetadata(SparseMetadata {
86 patches: array.patches().to_metadata(array.len(), array.dtype())?,
87 }))
88 }
89
90 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
91 Ok(Some(metadata.0.encode_to_vec()))
92 }
93
94 fn deserialize(
95 bytes: &[u8],
96 _dtype: &DType,
97 _len: usize,
98 _buffers: &[BufferHandle],
99 _session: &VortexSession,
100 ) -> VortexResult<Self::Metadata> {
101 Ok(ProstMetadata(SparseMetadata::decode(bytes)?))
102 }
103
104 fn build(
105 dtype: &DType,
106 len: usize,
107 metadata: &Self::Metadata,
108 buffers: &[BufferHandle],
109 children: &dyn ArrayChildren,
110 ) -> VortexResult<SparseArray> {
111 if children.len() != 2 {
112 vortex_bail!(
113 "Expected 2 children for sparse encoding, found {}",
114 children.len()
115 )
116 }
117 vortex_ensure!(
118 metadata.0.patches.offset()? == 0,
119 "Patches must start at offset 0"
120 );
121
122 let patch_indices = children.get(
123 0,
124 &metadata.0.patches.indices_dtype()?,
125 metadata.0.patches.len()?,
126 )?;
127 let patch_values = children.get(1, dtype, metadata.0.patches.len()?)?;
128
129 if buffers.len() != 1 {
130 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
131 }
132
133 let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
134 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
135
136 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
137
138 SparseArray::try_new(patch_indices, patch_values, len, fill_value)
139 }
140
141 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
142 vortex_ensure!(
143 children.len() == 2,
144 "SparseArray expects 2 children, got {}",
145 children.len()
146 );
147
148 let mut children_iter = children.into_iter();
149 let patch_indices = children_iter.next().vortex_expect("patch_indices child");
150 let patch_values = children_iter.next().vortex_expect("patch_values child");
151
152 array.patches = Patches::new(
153 array.patches.array_len(),
154 array.patches.offset(),
155 patch_indices,
156 patch_values,
157 array.patches.chunk_offsets().clone(),
158 )?;
159
160 Ok(())
161 }
162
163 fn reduce_parent(
164 array: &Self::Array,
165 parent: &ArrayRef,
166 child_idx: usize,
167 ) -> VortexResult<Option<ArrayRef>> {
168 RULES.evaluate(array, parent, child_idx)
169 }
170
171 fn execute_parent(
172 array: &Self::Array,
173 parent: &ArrayRef,
174 child_idx: usize,
175 ctx: &mut ExecutionCtx,
176 ) -> VortexResult<Option<ArrayRef>> {
177 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
178 }
179
180 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
181 execute_sparse(array)
182 }
183}
184
/// An array in which every position holds `fill_value` except the positions overridden by
/// `patches`.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Indices and values of the overridden positions; also carries the logical array length.
    patches: Patches,
    /// Value of every position that is not patched; its dtype is reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics holder exposed through the vtable's `stats` accessor.
    stats_set: ArrayStats,
}
191
/// Marker type carrying the vtable implementations for [`SparseArray`].
#[derive(Debug)]
pub struct SparseVTable;
194
impl SparseVTable {
    /// Identifier of the sparse encoding, `"vortex.sparse"`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
198
199impl SparseArray {
200 pub fn try_new(
201 indices: ArrayRef,
202 values: ArrayRef,
203 len: usize,
204 fill_value: Scalar,
205 ) -> VortexResult<Self> {
206 vortex_ensure!(
207 indices.len() == values.len(),
208 "Mismatched indices {} and values {} length",
209 indices.len(),
210 values.len()
211 );
212
213 if indices.is_host() {
214 debug_assert_eq!(
215 indices.statistics().compute_is_strict_sorted(),
216 Some(true),
217 "SparseArray: indices must be strict-sorted"
218 );
219
220 if !indices.is_empty() {
222 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
223
224 vortex_ensure!(
225 last_index < len,
226 "Array length was {len} but the last index is {last_index}"
227 );
228 }
229 }
230
231 Ok(Self {
232 patches: Patches::new(len, 0, indices, values, None)?,
234 fill_value,
235 stats_set: Default::default(),
236 })
237 }
238
239 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
241 vortex_ensure!(
242 fill_value.dtype() == patches.values().dtype(),
243 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
244 fill_value,
245 patches.values().dtype(),
246 fill_value.dtype(),
247 );
248
249 Ok(Self {
250 patches,
251 fill_value,
252 stats_set: Default::default(),
253 })
254 }
255
256 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
257 Self {
258 patches,
259 fill_value,
260 stats_set: Default::default(),
261 }
262 }
263
264 #[inline]
265 pub fn patches(&self) -> &Patches {
266 &self.patches
267 }
268
269 #[inline]
270 pub fn resolved_patches(&self) -> VortexResult<Patches> {
271 let patches = self.patches();
272 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
273 let indices = patches.indices().to_array().binary(
274 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
275 Operator::Sub,
276 )?;
277
278 Patches::new(
279 patches.array_len(),
280 0,
281 indices,
282 patches.values().clone(),
283 None,
285 )
286 }
287
288 #[inline]
289 pub fn fill_scalar(&self) -> &Scalar {
290 &self.fill_value
291 }
292
293 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
297 if let Some(fill_value) = fill_value.as_ref()
298 && array.dtype() != fill_value.dtype()
299 {
300 vortex_bail!(
301 "Array and fill value types must match. got {} and {}",
302 array.dtype(),
303 fill_value.dtype()
304 )
305 }
306 let mask = array.validity_mask()?;
307
308 if mask.all_false() {
309 return Ok(
311 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
312 );
313 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
314 let non_null_values = filter(array, &mask)?;
316 let non_null_indices = match mask.indices() {
317 AllOr::All => {
318 unreachable!("Mask is mostly null")
320 }
321 AllOr::None => {
322 unreachable!("Mask is mostly null but not all null")
324 }
325 AllOr::Some(values) => {
326 let buffer: Buffer<u32> = values
327 .iter()
328 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
329 .collect();
330
331 buffer.into_array()
332 }
333 };
334
335 return Ok(SparseArray::try_new(
336 non_null_indices,
337 non_null_values,
338 array.len(),
339 Scalar::null(array.dtype().clone()),
340 )?
341 .into_array());
342 }
343
344 let fill = if let Some(fill) = fill_value {
345 fill
346 } else {
347 let (top_pvalue, _) = array
349 .to_primitive()
350 .top_value()?
351 .vortex_expect("Non empty or all null array");
352
353 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
354 };
355
356 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
357 let non_top_mask = Mask::from_buffer(
358 array
359 .to_array()
360 .binary(fill_array.clone(), Operator::NotEq)?
361 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
362 .to_bool()
363 .to_bit_buffer(),
364 );
365
366 let non_top_values = filter(array, &non_top_mask)?;
367
368 let indices: Buffer<u64> = match non_top_mask {
369 Mask::AllTrue(count) => {
370 (0u64..count as u64).collect()
372 }
373 Mask::AllFalse(_) => {
374 return Ok(fill_array);
376 }
377 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
378 };
379
380 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
381 .map(|a| a.into_array())
382 }
383}
384
385impl BaseArrayVTable<SparseVTable> for SparseVTable {
386 fn len(array: &SparseArray) -> usize {
387 array.patches.array_len()
388 }
389
390 fn dtype(array: &SparseArray) -> &DType {
391 array.fill_scalar().dtype()
392 }
393
394 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
395 array.stats_set.to_ref(array.as_ref())
396 }
397
398 fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
399 array.patches.array_hash(state, precision);
400 array.fill_value.hash(state);
401 }
402
403 fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
404 array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
405 }
406}
407
408impl ValidityVTable<SparseVTable> for SparseVTable {
409 fn validity(array: &SparseArray) -> VortexResult<Validity> {
410 let patches = unsafe {
411 Patches::new_unchecked(
412 array.patches.array_len(),
413 array.patches.offset(),
414 array.patches.indices().clone(),
415 array
416 .patches
417 .values()
418 .validity()?
419 .to_array(array.patches.values().len()),
420 array.patches.chunk_offsets().clone(),
421 array.patches.offset_within_chunk(),
422 )
423 };
424
425 Ok(Validity::Array(
426 unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
427 .into_array(),
428 ))
429 }
430}
431
432impl VisitorVTable<SparseVTable> for SparseVTable {
433 fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
434 let fill_value_buffer =
435 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
436 visitor.visit_buffer_handle("fill_value", &BufferHandle::new_host(fill_value_buffer));
437 }
438
439 fn nbuffers(_array: &SparseArray) -> usize {
440 1
441 }
442
443 fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
444 visitor.visit_patches(array.patches())
445 }
446
447 fn nchildren(array: &SparseArray) -> usize {
448 2 + array.patches().chunk_offsets().is_some() as usize
450 }
451}
452
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    // A null fill value of nullable i32 dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    // A fill value built from the plain i32 `42`.
    fn non_nullable_fill() -> Scalar {
        42i32.into()
    }

    // Length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8; the values are
    // cast to the dtype of the fill.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        let sparse = SparseArray::try_new(indices, values, 10, fill_value).unwrap();
        sparse.into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (index, expected) in expectations {
            assert_eq!(array.scalar_at(index).unwrap(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let lookup = array.scalar_at(10);
        lookup.unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        // Both ends of the array fall back to the null fill.
        for index in [0, 99] {
            assert!(arr.scalar_at(index).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let expected = Mask::from_iter([true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let indices = buffer![2u64, 5, 8].into_array();
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let fill = Scalar::primitive(1.0f32, Nullability::Nullable);

        let sliced = SparseArray::try_new(indices, null_values, 10, fill)
            .unwrap()
            .slice(2..7)
            .unwrap();

        let expected = Mask::from_iter([false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let once_value = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let twice_value = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let actual = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        // Only the patched positions 2, 5 and 8 are valid.
        let expected = (0..10).map(|i| matches!(i, 2 | 5 | 8)).collect_vec();
        assert_eq!(actual, expected);
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 is out of range for an array of length 100.
        let result = SparseArray::try_new(indices, values, 100, 0_u32.into());
        result.unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 is the last valid position of an array of length 101.
        let result = SparseArray::try_new(indices, values, 101, 0_u32.into());
        result.unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        );

        let input = original.clone().into_array();
        let sparse = SparseArray::encode(&input, None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let fill = Scalar::null_native::<i16>();

        let array = SparseArray::try_new(indices, values, 10, fill).unwrap();

        // Only the non-null patches (at positions 0, 2 and 8) are valid.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        assert_eq!(array.validity_mask().unwrap(), expected);
    }
}