1use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, IntegerPType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
// Invokes the `vtable!` macro imported from `vortex_array` for the `Sparse` encoding.
// NOTE(review): `SparseVTable` is used below but never declared in this file's visible
// text, so it is presumably generated by this macro — confirm against the macro definition.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28 type Array = SparseArray;
29 type Encoding = SparseEncoding;
30
31 type ArrayVTable = Self;
32 type CanonicalVTable = Self;
33 type OperationsVTable = Self;
34 type ValidityVTable = Self;
35 type VisitorVTable = Self;
36 type ComputeVTable = NotSupported;
37 type EncodeVTable = Self;
38 type SerdeVTable = Self;
39 type PipelineVTable = NotSupported;
40
41 fn id(_encoding: &Self::Encoding) -> EncodingId {
42 EncodingId::new_ref("vortex.sparse")
43 }
44
45 fn encoding(_array: &Self::Array) -> EncodingRef {
46 EncodingRef::new_ref(SparseEncoding.as_ref())
47 }
48}
49
50#[derive(Clone, Debug)]
51pub struct SparseArray {
52 patches: Patches,
53 fill_value: Scalar,
54 stats_set: ArrayStats,
55}
56
/// Unit marker type for the sparse encoding, used as `VTable::Encoding` for `SparseVTable`.
#[derive(Debug, Clone)]
pub struct SparseEncoding;
59
60impl SparseArray {
61 pub fn try_new(
62 indices: ArrayRef,
63 values: ArrayRef,
64 len: usize,
65 fill_value: Scalar,
66 ) -> VortexResult<Self> {
67 vortex_ensure!(
68 indices.len() == values.len(),
69 "Mismatched indices {} and values {} length",
70 indices.len(),
71 values.len()
72 );
73
74 vortex_ensure!(
75 indices.statistics().compute_is_strict_sorted() == Some(true),
76 "SparseArray: indices must be strict-sorted"
77 );
78
79 if !indices.is_empty() {
81 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
82
83 vortex_ensure!(
84 last_index < len,
85 "Array length was {len} but the last index is {last_index}"
86 );
87 }
88
89 Ok(Self {
90 patches: Patches::new(len, 0, indices, values, None),
92 fill_value,
93 stats_set: Default::default(),
94 })
95 }
96
97 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
99 vortex_ensure!(
100 fill_value.dtype() == patches.values().dtype(),
101 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
102 fill_value,
103 patches.values().dtype(),
104 fill_value.dtype(),
105 );
106
107 Ok(Self {
108 patches,
109 fill_value,
110 stats_set: Default::default(),
111 })
112 }
113
114 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
115 Self {
116 patches,
117 fill_value,
118 stats_set: Default::default(),
119 }
120 }
121
122 #[inline]
123 pub fn patches(&self) -> &Patches {
124 &self.patches
125 }
126
127 #[inline]
128 pub fn resolved_patches(&self) -> Patches {
129 let patches = self.patches();
130 let indices_offset = Scalar::from(patches.offset())
131 .cast(patches.indices().dtype())
132 .vortex_expect("Patches offset must cast to the indices dtype");
133 let indices = sub_scalar(patches.indices(), indices_offset)
134 .vortex_expect("must be able to subtract offset from indices");
135
136 Patches::new(
137 patches.array_len(),
138 0,
139 indices,
140 patches.values().clone(),
141 None,
143 )
144 }
145
146 #[inline]
147 pub fn fill_scalar(&self) -> &Scalar {
148 &self.fill_value
149 }
150
151 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
155 if let Some(fill_value) = fill_value.as_ref()
156 && array.dtype() != fill_value.dtype()
157 {
158 vortex_bail!(
159 "Array and fill value types must match. got {} and {}",
160 array.dtype(),
161 fill_value.dtype()
162 )
163 }
164 let mask = array.validity_mask();
165
166 if mask.all_false() {
167 return Ok(
169 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
170 );
171 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
172 let non_null_values = filter(array, &mask)?;
174 let non_null_indices = match mask.indices() {
175 AllOr::All => {
176 unreachable!("Mask is mostly null")
178 }
179 AllOr::None => {
180 unreachable!("Mask is mostly null but not all null")
182 }
183 AllOr::Some(values) => {
184 let buffer: Buffer<u32> = values
185 .iter()
186 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
187 .collect();
188
189 buffer.into_array()
190 }
191 };
192
193 return Ok(SparseArray::try_new(
194 non_null_indices,
195 non_null_values,
196 array.len(),
197 Scalar::null(array.dtype().clone()),
198 )?
199 .into_array());
200 }
201
202 let fill = if let Some(fill) = fill_value {
203 fill
204 } else {
205 let (top_pvalue, _) = array
207 .to_primitive()
208 .top_value()?
209 .vortex_expect("Non empty or all null array");
210
211 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
212 };
213
214 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
215 let non_top_mask = Mask::from_buffer(
216 fill_null(
217 &compare(array, &fill_array, Operator::NotEq)?,
218 &Scalar::bool(true, Nullability::NonNullable),
219 )?
220 .to_bool()
221 .boolean_buffer()
222 .clone(),
223 );
224
225 let non_top_values = filter(array, &non_top_mask)?;
226
227 let indices: Buffer<u64> = match non_top_mask {
228 Mask::AllTrue(count) => {
229 (0u64..count as u64).collect()
231 }
232 Mask::AllFalse(_) => {
233 return Ok(fill_array);
235 }
236 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
237 };
238
239 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
240 .map(|a| a.into_array())
241 }
242}
243
244impl ArrayVTable<SparseVTable> for SparseVTable {
245 fn len(array: &SparseArray) -> usize {
246 array.patches.array_len()
247 }
248
249 fn dtype(array: &SparseArray) -> &DType {
250 array.fill_scalar().dtype()
251 }
252
253 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
254 array.stats_set.to_ref(array.as_ref())
255 }
256}
257
258impl ValidityVTable<SparseVTable> for SparseVTable {
259 fn is_valid(array: &SparseArray, index: usize) -> bool {
260 match array.patches().get_patched(index) {
261 None => array.fill_scalar().is_valid(),
262 Some(patch_value) => patch_value.is_valid(),
263 }
264 }
265
266 fn all_valid(array: &SparseArray) -> bool {
267 if array.fill_scalar().is_null() {
268 return array.patches().values().len() == array.len()
270 && array.patches().values().all_valid();
271 }
272
273 array.patches().values().all_valid()
274 }
275
276 fn all_invalid(array: &SparseArray) -> bool {
277 if !array.fill_scalar().is_null() {
278 return array.patches().values().len() == array.len()
280 && array.patches().values().all_invalid();
281 }
282
283 array.patches().values().all_invalid()
284 }
285
286 #[allow(clippy::unnecessary_fallible_conversions)]
287 fn validity_mask(array: &SparseArray) -> Mask {
288 let fill_is_valid = array.fill_scalar().is_valid();
289 let values_validity = array.patches().values().validity_mask();
290 let len = array.len();
291
292 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
293 return Mask::AllTrue(len);
294 }
295 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
296 return Mask::AllFalse(len);
297 }
298
299 let mut is_valid_buffer = BooleanBufferBuilder::new(len);
301 is_valid_buffer.append_n(len, fill_is_valid);
302
303 let indices = array.patches().indices().to_primitive();
304 let index_offset = array.patches().offset();
305
306 match_each_integer_ptype!(indices.ptype(), |I| {
307 let indices = indices.as_slice::<I>();
308 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
309 });
310
311 Mask::from_buffer(is_valid_buffer.finish())
312 }
313}
314
315fn patch_validity<I: IntegerPType>(
316 is_valid_buffer: &mut BooleanBufferBuilder,
317 indices: &[I],
318 index_offset: usize,
319 values_validity: Mask,
320) {
321 let indices = indices.iter().map(|index| {
322 let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
323 index - index_offset
324 });
325 match values_validity {
326 Mask::AllTrue(_) => {
327 for index in indices {
328 is_valid_buffer.set_bit(index, true);
329 }
330 }
331 Mask::AllFalse(_) => {
332 for index in indices {
333 is_valid_buffer.set_bit(index, false);
334 }
335 }
336 Mask::Values(mask_values) => {
337 let is_valid = mask_values.boolean_buffer().iter();
338 for (index, is_valid) in indices.zip_eq(is_valid) {
339 is_valid_buffer.set_bit(index, is_valid);
340 }
341 }
342 }
343}
344
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// The scalar `42i32`.
    fn non_nullable_fill() -> Scalar {
        42i32.into()
    }

    /// Ten-element fixture with values 100, 200, 300 at positions 2, 5, 8 and
    /// `fill_value` everywhere else. The values are cast to the dtype of the fill.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let sparse = sparse_array(nullable_fill());

        // Unpatched position yields the fill; patched positions yield the patch.
        assert_eq!(sparse.scalar_at(0), nullable_fill());
        assert_eq!(sparse.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(sparse.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let sparse = sparse_array(nullable_fill());
        let _ = sparse.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let sparse = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = PrimitiveScalar::try_from(&sparse.scalar_at(10))
            .unwrap()
            .typed_value::<u32>();
        assert_eq!(patched, Some(1234));
        assert!(sparse.scalar_at(0).is_null());
        assert!(sparse.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&window.scalar_at(0)).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Null patches over a valid fill: validity is the inverse of the patch positions.
        let null_f32 = Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable));
        let null_patches = ConstantArray::new(null_f32, 3).into_array();
        let window = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let first_cut = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&first_cut.scalar_at(1)).unwrap(), 100);

        let second_cut = first_cut.slice(1..6);
        assert_eq!(usize::try_from(&second_cut.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let sparse = sparse_array(nullable_fill());
        let bits = sparse
            .validity_mask()
            .to_boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let sparse = sparse_array(non_nullable_fill());
        assert!(sparse.validity_mask().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not below the declared length (100).
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is below the declared length (101).
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        // The same validity pattern is fed in and expected back out.
        let validity = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive();

        assert_eq!(sparse.validity_mask(), Mask::from_iter(validity));
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let sparse =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_typed::<i16>())
                .unwrap();

        // Only positions 0, 2 and 8 carry a non-null patch.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = sparse.validity_mask();

        assert_eq!(actual, expected);
    }
}