1use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
// Submodules holding the remaining pieces of the sparse encoding (canonicalization, compute,
// operations and serialization); their contents are not visible in this file.
mod canonical;
mod compute;
mod ops;
mod serde;

// Macro imported from `vortex_array`. Presumably it generates the `SparseVTable` type that the
// trait impls in this file are written against (it is not declared anywhere else here) — confirm.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28 type Array = SparseArray;
29 type Encoding = SparseEncoding;
30
31 type ArrayVTable = Self;
32 type CanonicalVTable = Self;
33 type OperationsVTable = Self;
34 type ValidityVTable = Self;
35 type VisitorVTable = Self;
36 type ComputeVTable = NotSupported;
37 type EncodeVTable = Self;
38 type SerdeVTable = Self;
39 type PipelineVTable = NotSupported;
40
41 fn id(_encoding: &Self::Encoding) -> EncodingId {
42 EncodingId::new_ref("vortex.sparse")
43 }
44
45 fn encoding(_array: &Self::Array) -> EncodingRef {
46 EncodingRef::new_ref(SparseEncoding.as_ref())
47 }
48}
49
/// An array made of one fill value plus explicit patches.
///
/// Positions listed in `patches` take their value from the matching patch. Every other position
/// reads as `fill_value`. Validity is resolved per position the same way.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicit (index, value) pairs; also records the logical length of the array.
    patches: Patches,
    /// Value of every unpatched position. Its dtype is also reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics set handed out through the array vtable's `stats`.
    stats_set: ArrayStats,
}
56
/// Encoding marker for [`SparseArray`] (encoding id `vortex.sparse`).
#[derive(Clone, Debug)]
pub struct SparseEncoding;
59
60impl SparseArray {
61 pub fn try_new(
62 indices: ArrayRef,
63 values: ArrayRef,
64 len: usize,
65 fill_value: Scalar,
66 ) -> VortexResult<Self> {
67 vortex_ensure!(
68 indices.len() == values.len(),
69 "Mismatched indices {} and values {} length",
70 indices.len(),
71 values.len()
72 );
73
74 vortex_ensure!(
75 indices.statistics().compute_is_strict_sorted() == Some(true),
76 "SparseArray: indices must be strict-sorted"
77 );
78
79 if !indices.is_empty() {
81 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
82
83 vortex_ensure!(
84 last_index < len,
85 "Array length was {len} but the last index is {last_index}"
86 );
87 }
88
89 let patches = Patches::new(len, 0, indices, values);
90
91 Ok(Self {
92 patches,
93 fill_value,
94 stats_set: Default::default(),
95 })
96 }
97
98 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
100 vortex_ensure!(
101 fill_value.dtype() == patches.values().dtype(),
102 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
103 fill_value,
104 patches.values().dtype(),
105 fill_value.dtype(),
106 );
107
108 Ok(Self {
109 patches,
110 fill_value,
111 stats_set: Default::default(),
112 })
113 }
114
115 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
116 Self {
117 patches,
118 fill_value,
119 stats_set: Default::default(),
120 }
121 }
122
123 #[inline]
124 pub fn patches(&self) -> &Patches {
125 &self.patches
126 }
127
128 #[inline]
129 pub fn resolved_patches(&self) -> Patches {
130 let patches = self.patches();
131 let indices_offset = Scalar::from(patches.offset())
132 .cast(patches.indices().dtype())
133 .vortex_expect("Patches offset must cast to the indices dtype");
134 let indices = sub_scalar(patches.indices(), indices_offset)
135 .vortex_expect("must be able to subtract offset from indices");
136 Patches::new(patches.array_len(), 0, indices, patches.values().clone())
137 }
138
139 #[inline]
140 pub fn fill_scalar(&self) -> &Scalar {
141 &self.fill_value
142 }
143
144 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
148 if let Some(fill_value) = fill_value.as_ref()
149 && array.dtype() != fill_value.dtype()
150 {
151 vortex_bail!(
152 "Array and fill value types must match. got {} and {}",
153 array.dtype(),
154 fill_value.dtype()
155 )
156 }
157 let mask = array.validity_mask();
158
159 if mask.all_false() {
160 return Ok(
162 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
163 );
164 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
165 let non_null_values = filter(array, &mask)?;
167 let non_null_indices = match mask.indices() {
168 AllOr::All => {
169 unreachable!("Mask is mostly null")
171 }
172 AllOr::None => {
173 unreachable!("Mask is mostly null but not all null")
175 }
176 AllOr::Some(values) => {
177 let buffer: Buffer<u32> = values
178 .iter()
179 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
180 .collect();
181
182 buffer.into_array()
183 }
184 };
185
186 return Ok(SparseArray::try_new(
187 non_null_indices,
188 non_null_values,
189 array.len(),
190 Scalar::null(array.dtype().clone()),
191 )?
192 .into_array());
193 }
194
195 let fill = if let Some(fill) = fill_value {
196 fill
197 } else {
198 let (top_pvalue, _) = array
200 .to_primitive()
201 .top_value()?
202 .vortex_expect("Non empty or all null array");
203
204 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
205 };
206
207 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
208 let non_top_mask = Mask::from_buffer(
209 fill_null(
210 &compare(array, &fill_array, Operator::NotEq)?,
211 &Scalar::bool(true, Nullability::NonNullable),
212 )?
213 .to_bool()
214 .boolean_buffer()
215 .clone(),
216 );
217
218 let non_top_values = filter(array, &non_top_mask)?;
219
220 let indices: Buffer<u64> = match non_top_mask {
221 Mask::AllTrue(count) => {
222 (0u64..count as u64).collect()
224 }
225 Mask::AllFalse(_) => {
226 return Ok(fill_array);
228 }
229 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
230 };
231
232 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
233 .map(|a| a.into_array())
234 }
235}
236
237impl ArrayVTable<SparseVTable> for SparseVTable {
238 fn len(array: &SparseArray) -> usize {
239 array.patches.array_len()
240 }
241
242 fn dtype(array: &SparseArray) -> &DType {
243 array.fill_scalar().dtype()
244 }
245
246 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
247 array.stats_set.to_ref(array.as_ref())
248 }
249}
250
251impl ValidityVTable<SparseVTable> for SparseVTable {
252 fn is_valid(array: &SparseArray, index: usize) -> bool {
253 match array.patches().get_patched(index) {
254 None => array.fill_scalar().is_valid(),
255 Some(patch_value) => patch_value.is_valid(),
256 }
257 }
258
259 fn all_valid(array: &SparseArray) -> bool {
260 if array.fill_scalar().is_null() {
261 return array.patches().values().len() == array.len()
263 && array.patches().values().all_valid();
264 }
265
266 array.patches().values().all_valid()
267 }
268
269 fn all_invalid(array: &SparseArray) -> bool {
270 if !array.fill_scalar().is_null() {
271 return array.patches().values().len() == array.len()
273 && array.patches().values().all_invalid();
274 }
275
276 array.patches().values().all_invalid()
277 }
278
279 #[allow(clippy::unnecessary_fallible_conversions)]
280 fn validity_mask(array: &SparseArray) -> Mask {
281 let fill_is_valid = array.fill_scalar().is_valid();
282 let values_validity = array.patches().values().validity_mask();
283 let len = array.len();
284
285 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
286 return Mask::AllTrue(len);
287 }
288 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
289 return Mask::AllFalse(len);
290 }
291
292 let mut is_valid_buffer = BooleanBufferBuilder::new(len);
294 is_valid_buffer.append_n(len, fill_is_valid);
295
296 let indices = array.patches().indices().to_primitive();
297 let index_offset = array.patches().offset();
298
299 match_each_integer_ptype!(indices.ptype(), |I| {
300 let indices = indices.as_slice::<I>();
301 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
302 });
303
304 Mask::from_buffer(is_valid_buffer.finish())
305 }
306}
307
308fn patch_validity<I: NativePType>(
309 is_valid_buffer: &mut BooleanBufferBuilder,
310 indices: &[I],
311 index_offset: usize,
312 values_validity: Mask,
313) {
314 let indices = indices.iter().map(|index| {
315 let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
316 index - index_offset
317 });
318 match values_validity {
319 Mask::AllTrue(_) => {
320 for index in indices {
321 is_valid_buffer.set_bit(index, true);
322 }
323 }
324 Mask::AllFalse(_) => {
325 for index in indices {
326 is_valid_buffer.set_bit(index, false);
327 }
328 }
329 Mask::Values(mask_values) => {
330 let is_valid = mask_values.boolean_buffer().iter();
331 for (index, is_valid) in indices.zip_eq(is_valid) {
332 is_valid_buffer.set_bit(index, is_valid);
333 }
334 }
335 }
336}
337
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// With a null fill: [null, null, 100, null, null, 200, null, null, 300, null].
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let _ = array.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    #[test]
    fn all_valid_and_all_invalid_with_null_fill() {
        // Three valid patches among ten positions: neither all valid nor all null.
        let array = sparse_array(nullable_fill());
        assert!(!array.all_valid());
        assert!(!array.all_invalid());
    }

    #[test]
    fn all_valid_with_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.all_valid());
        assert!(!array.all_invalid());
    }

    #[test]
    fn all_valid_when_patches_cover_null_fill() {
        // The null fill is never observed because every position is patched.
        let array = SparseArray::try_new(
            buffer![0u64, 1, 2].into_array(),
            PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3)]).into_array(),
            3,
            nullable_fill(),
        )
        .unwrap()
        .into_array();
        assert!(array.all_valid());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive();
        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn encode_with_null_fill_value() {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(1), Some(2), None])
            .into_array();
        let sparse =
            SparseArray::encode(&array, Some(Scalar::null_typed::<i32>())).vortex_unwrap();

        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter([true, false, true, true, false])
        );
        let canonical = sparse.to_primitive();
        let values = canonical.as_slice::<i32>();
        // Values at null positions are unspecified; only check the valid ones.
        assert_eq!([values[0], values[2], values[3]], [1, 1, 2]);
    }

    #[test]
    fn encode_mostly_null() {
        // 19 of 20 elements are null, which takes the mostly-null path.
        let mut items = vec![None; 20];
        items[3] = Some(7i32);
        let array = PrimitiveArray::from_option_iter(items.clone()).into_array();
        let sparse = SparseArray::encode(&array, None).vortex_unwrap();

        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(items.iter().map(Option::is_some))
        );
        assert_eq!(sparse.scalar_at(3), Scalar::from(Some(7i32)));
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}