1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use kernel::PARENT_KERNELS;
11use prost::Message as _;
12use vortex_array::Array;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayId;
16use vortex_array::ArrayParts;
17use vortex_array::ArrayRef;
18use vortex_array::ArrayView;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::ToCanonical;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::bool::BoolArrayExt;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::builtins::ArrayBuiltins;
28use vortex_array::dtype::DType;
29use vortex_array::dtype::Nullability;
30use vortex_array::patches::Patches;
31use vortex_array::patches::PatchesMetadata;
32use vortex_array::scalar::Scalar;
33use vortex_array::scalar::ScalarValue;
34use vortex_array::scalar_fn::fns::operators::Operator;
35use vortex_array::serde::ArrayChildren;
36use vortex_array::validity::Validity;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_buffer::Buffer;
40use vortex_buffer::ByteBufferMut;
41use vortex_error::VortexExpect as _;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_ensure_eq;
46use vortex_error::vortex_panic;
47use vortex_mask::AllOr;
48use vortex_mask::Mask;
49use vortex_session::VortexSession;
50
51use crate::canonical::execute_sparse;
52use crate::rules::RULES;
53
54mod canonical;
55mod compute;
56mod kernel;
57mod ops;
58mod rules;
59mod slice;
60
/// An array in the [`Sparse`] encoding: positions listed in its patches hold the patch values,
/// every other position holds a single shared fill value.
pub type SparseArray = Array<Sparse>;
63
/// Serialized metadata of a [`SparseArray`].
///
/// Only the patches' metadata lives here. The fill value is written as the array's single buffer,
/// and the patch indices and values are written as children 0 and 1.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Patch metadata: number of patches, offset, and the dtype of the patch indices.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
70
71impl ArrayHash for SparseData {
72 fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
73 self.patches.array_hash(state, precision);
74 self.fill_value.hash(state);
75 }
76}
77
78impl ArrayEq for SparseData {
79 fn array_eq(&self, other: &Self, precision: Precision) -> bool {
80 self.patches.array_eq(&other.patches, precision) && self.fill_value == other.fill_value
81 }
82}
83
84impl VTable for Sparse {
85 type ArrayData = SparseData;
86
87 type OperationsVTable = Self;
88 type ValidityVTable = Self;
89
90 fn id(&self) -> ArrayId {
91 Self::ID
92 }
93
94 fn validate(
95 &self,
96 data: &Self::ArrayData,
97 dtype: &DType,
98 len: usize,
99 _slots: &[Option<ArrayRef>],
100 ) -> VortexResult<()> {
101 SparseData::validate(data.patches(), data.fill_scalar(), dtype, len)
102 }
103
104 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
105 1
106 }
107
108 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
109 match idx {
110 0 => {
111 let fill_value_buffer =
112 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
113 BufferHandle::new_host(fill_value_buffer)
114 }
115 _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
116 }
117 }
118
119 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
120 match idx {
121 0 => Some("fill_value".to_string()),
122 _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
123 }
124 }
125
126 fn serialize(
127 array: ArrayView<'_, Self>,
128 _session: &VortexSession,
129 ) -> VortexResult<Option<Vec<u8>>> {
130 let patches = array.patches().to_metadata(array.len(), array.dtype())?;
131 let metadata = SparseMetadata { patches };
132
133 Ok(Some(metadata.encode_to_vec()))
135 }
136
137 fn deserialize(
138 &self,
139 dtype: &DType,
140 len: usize,
141 metadata: &[u8],
142 buffers: &[BufferHandle],
143 children: &dyn ArrayChildren,
144 session: &VortexSession,
145 ) -> VortexResult<ArrayParts<Self>> {
146 let metadata = SparseMetadata::decode(metadata)?;
147
148 if buffers.len() != 1 {
151 vortex_bail!("Expected 1 buffer, got {}", buffers.len());
152 }
153 let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
154
155 let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
156 let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
157
158 vortex_ensure_eq!(
159 children.len(),
160 2,
161 "SparseArray expects 2 children for sparse encoding, found {}",
162 children.len()
163 );
164
165 let patch_indices = children.get(
166 0,
167 &metadata.patches.indices_dtype()?,
168 metadata.patches.len()?,
169 )?;
170 let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
171
172 let patches = Patches::new(
173 len,
174 metadata.patches.offset()?,
175 patch_indices,
176 patch_values,
177 None,
178 )?;
179 let slots = SparseData::make_slots(&patches);
180 let data = SparseData::try_new_from_patches(patches, fill_value)?;
181 Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
182 }
183
184 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
185 SLOT_NAMES[idx].to_string()
186 }
187
188 fn reduce_parent(
189 array: ArrayView<'_, Self>,
190 parent: &ArrayRef,
191 child_idx: usize,
192 ) -> VortexResult<Option<ArrayRef>> {
193 RULES.evaluate(array, parent, child_idx)
194 }
195
196 fn execute_parent(
197 array: ArrayView<'_, Self>,
198 parent: &ArrayRef,
199 child_idx: usize,
200 ctx: &mut ExecutionCtx,
201 ) -> VortexResult<Option<ArrayRef>> {
202 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
203 }
204
205 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
206 execute_sparse(&array, ctx).map(ExecutionResult::done)
207 }
208}
209
/// Number of child slots carried by a sparse array.
pub(crate) const NUM_SLOTS: usize = 3;
/// Slot names in slot order: patch indices, patch values, and the optional patch chunk offsets.
pub(crate) const SLOT_NAMES: [&str; NUM_SLOTS] =
    ["patch_indices", "patch_values", "patch_chunk_offsets"];
213
/// Encoding-specific payload of a [`SparseArray`].
///
/// Positions covered by `patches` read from the patch values; every other position reads
/// `fill_value`.
#[derive(Clone, Debug)]
pub struct SparseData {
    /// Explicitly stored positions and their values.
    patches: Patches,
    /// Value of every position not covered by `patches`; its dtype is the array's dtype.
    fill_value: Scalar,
}
219
220impl Display for SparseData {
221 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
222 write!(f, "fill_value: {}", self.fill_value)
223 }
224}
225
/// Marker type for the sparse encoding; it implements the encoding's vtables.
#[derive(Clone, Debug)]
pub struct Sparse;
228
229impl Sparse {
230 pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
231
232 pub fn try_new(
234 indices: ArrayRef,
235 values: ArrayRef,
236 len: usize,
237 fill_value: Scalar,
238 ) -> VortexResult<SparseArray> {
239 let dtype = fill_value.dtype().clone();
240 let patches = Patches::new(len, 0, indices, values, None)?;
241 let slots = SparseData::make_slots(&patches);
242 let data = SparseData::try_new_from_patches(patches, fill_value)?;
243 Ok(unsafe {
244 Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
245 })
246 }
247
248 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<SparseArray> {
249 let dtype = fill_value.dtype().clone();
250 let len = patches.array_len();
251 let slots = SparseData::make_slots(&patches);
252 let data = SparseData::try_new_from_patches(patches, fill_value)?;
253 Ok(unsafe {
254 Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
255 })
256 }
257
258 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> SparseArray {
259 let dtype = fill_value.dtype().clone();
260 let len = patches.array_len();
261 let slots = SparseData::make_slots(&patches);
262 let data = unsafe { SparseData::new_unchecked(patches, fill_value) };
263 unsafe {
264 Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
265 }
266 }
267
268 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
270 SparseData::encode(array, fill_value)
271 }
272}
273
274impl SparseData {
275 fn normalize_patches_dtype(patches: Patches, fill_value: &Scalar) -> VortexResult<Patches> {
276 let fill_dtype = fill_value.dtype();
277 let values_dtype = patches.values().dtype();
278
279 vortex_ensure!(
280 values_dtype.eq_ignore_nullability(fill_dtype),
281 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
282 fill_value,
283 values_dtype,
284 fill_dtype,
285 );
286
287 if values_dtype == fill_dtype {
288 Ok(patches)
289 } else {
290 patches.cast_values(fill_dtype)
291 }
292 }
293
294 pub fn validate(
295 patches: &Patches,
296 fill_value: &Scalar,
297 dtype: &DType,
298 len: usize,
299 ) -> VortexResult<()> {
300 vortex_ensure!(
301 fill_value.dtype() == dtype,
302 "fill value dtype {} does not match array dtype {}",
303 fill_value.dtype(),
304 dtype,
305 );
306 vortex_ensure!(
307 patches.array_len() == len,
308 "patches length {} does not match array length {}",
309 patches.array_len(),
310 len
311 );
312 vortex_ensure!(
313 patches.values().dtype() == dtype,
314 "patch values dtype {} does not match array dtype {}",
315 patches.values().dtype(),
316 dtype,
317 );
318 Ok(())
319 }
320
321 fn make_slots(patches: &Patches) -> Vec<Option<ArrayRef>> {
322 vec![
323 Some(patches.indices().clone()),
324 Some(patches.values().clone()),
325 patches.chunk_offsets().clone(),
326 ]
327 }
328
329 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
331 let patches = Self::normalize_patches_dtype(patches, &fill_value)?;
332 Ok(Self {
333 patches,
334 fill_value,
335 })
336 }
337
338 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
339 Self {
340 patches,
341 fill_value,
342 }
343 }
344
345 #[inline]
347 pub fn len(&self) -> usize {
348 self.patches.array_len()
349 }
350
351 #[inline]
353 pub fn is_empty(&self) -> bool {
354 self.patches.array_len() == 0
355 }
356
357 #[inline]
359 pub fn dtype(&self) -> &DType {
360 self.fill_scalar().dtype()
361 }
362
363 #[inline]
364 pub fn patches(&self) -> &Patches {
365 &self.patches
366 }
367
368 #[inline]
369 pub fn resolved_patches(&self) -> VortexResult<Patches> {
370 let patches = self.patches();
371 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
372 let indices = patches.indices().binary(
373 ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
374 Operator::Sub,
375 )?;
376
377 Patches::new(
378 patches.array_len(),
379 0,
380 indices,
381 patches.values().clone(),
382 None,
384 )
385 }
386
387 #[inline]
388 pub fn fill_scalar(&self) -> &Scalar {
389 &self.fill_value
390 }
391
392 pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
396 if let Some(fill_value) = fill_value.as_ref()
397 && !array.dtype().eq_ignore_nullability(fill_value.dtype())
398 {
399 vortex_bail!(
400 "Array and fill value types must have the same base type. got {} and {}",
401 array.dtype(),
402 fill_value.dtype()
403 )
404 }
405 let mask = array.validity_mask()?;
406
407 if mask.all_false() {
408 return Ok(
410 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
411 );
412 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
413 let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
416 let non_null_indices = match mask.indices() {
417 AllOr::All => {
418 unreachable!("Mask is mostly null")
420 }
421 AllOr::None => {
422 unreachable!("Mask is mostly null but not all null")
424 }
425 AllOr::Some(values) => {
426 let buffer: Buffer<u32> = values
427 .iter()
428 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
429 .collect();
430
431 buffer.into_array()
432 }
433 };
434
435 return Sparse::try_new(
436 non_null_indices,
437 non_null_values,
438 array.len(),
439 Scalar::null(array.dtype().clone()),
440 )
441 .map(IntoArray::into_array);
442 }
443
444 let fill = if let Some(fill) = fill_value {
445 fill.cast(array.dtype())?
446 } else {
447 let (top_pvalue, _) = array
449 .to_primitive()
450 .top_value()?
451 .vortex_expect("Non empty or all null array");
452
453 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
454 };
455
456 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
457 let non_top_mask = Mask::from_buffer(
458 array
459 .binary(fill_array.clone(), Operator::NotEq)?
460 .fill_null(Scalar::bool(true, Nullability::NonNullable))?
461 .to_bool()
462 .to_bit_buffer(),
463 );
464
465 let non_top_values = array
466 .filter(non_top_mask.clone())?
467 .to_canonical()?
468 .into_array();
469
470 let indices: Buffer<u64> = match non_top_mask {
471 Mask::AllTrue(count) => {
472 (0u64..count as u64).collect()
474 }
475 Mask::AllFalse(_) => {
476 return Ok(fill_array);
478 }
479 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
480 };
481
482 Sparse::try_new(indices.into_array(), non_top_values, array.len(), fill)
483 .map(IntoArray::into_array)
484 }
485}
486
487impl ValidityVTable<Sparse> for Sparse {
488 fn validity(array: ArrayView<'_, Sparse>) -> VortexResult<Validity> {
489 let patches = unsafe {
490 Patches::new_unchecked(
491 array.patches.array_len(),
492 array.patches.offset(),
493 array.patches.indices().clone(),
494 array
495 .patches
496 .values()
497 .validity()?
498 .to_array(array.patches.values().len()),
499 array.patches.chunk_offsets().clone(),
500 array.patches.offset_within_chunk(),
501 )
502 };
503
504 Ok(Validity::Array(
505 unsafe { Sparse::new_unchecked(patches, array.fill_value.is_valid().into()) }
506 .into_array(),
507 ))
508 }
509}
510
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::Sparse;

    /// A null fill value of nullable `i32`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value (42).
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 array holding 100, 200 and 300 at positions 2, 5 and 8, and `fill_value`
    /// elsewhere. The patch values are cast to the fill value's dtype first.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let patch_values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();

        Sparse::try_new(buffer![2u64, 5, 8].into_array(), patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());
        let cases = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];

        for (position, expected) in cases {
            assert_eq!(array.scalar_at(position).unwrap(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = Sparse::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        // Unpatched positions at both ends read the null fill.
        for position in [0, 99] {
            assert!(arr.scalar_at(position).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        // Original positions 2..7: only 2 and 5 are patched with valid values.
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let all_null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let array = Sparse::try_new(
            buffer![2u64, 5, 8].into_array(),
            all_null_values,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap();
        let sliced = array.slice(2..7).unwrap();

        // Original positions 2 and 5 are patched with nulls; the rest read the valid fill.
        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let once = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&once).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let twice = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&twice).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let bits = sparse_array(nullable_fill())
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();
        let expected = [
            false, false, true, false, false, true, false, false, true, false,
        ];
        assert_eq!(bits, expected);
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // Index 100 is out of range for a length-100 array.
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = [
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits),
        );

        let sparse = Sparse::encode(&original.clone().into_array(), None)
            .vortex_expect("Sparse::encode should succeed for test data");

        assert_eq!(sparse.validity_mask().unwrap(), Mask::from_iter(validity_bits));
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let array = Sparse::try_new(
            buffer![0u8, 2, 4, 6, 8].into_array(),
            patch_values,
            10,
            Scalar::null_native::<i16>(),
        )
        .unwrap();

        // Patched nulls (positions 4 and 6) and the null fill must all read as invalid.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        assert_eq!(array.validity_mask().unwrap(), expected);
    }
}