1use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
25vtable!(Sparse);
26
27impl VTable for SparseVTable {
28 type Array = SparseArray;
29 type Encoding = SparseEncoding;
30
31 type ArrayVTable = Self;
32 type CanonicalVTable = Self;
33 type OperationsVTable = Self;
34 type ValidityVTable = Self;
35 type VisitorVTable = Self;
36 type ComputeVTable = NotSupported;
37 type EncodeVTable = Self;
38 type SerdeVTable = Self;
39
40 fn id(_encoding: &Self::Encoding) -> EncodingId {
41 EncodingId::new_ref("vortex.sparse")
42 }
43
44 fn encoding(_array: &Self::Array) -> EncodingRef {
45 EncodingRef::new_ref(SparseEncoding.as_ref())
46 }
47}
48
49#[derive(Clone, Debug)]
50pub struct SparseArray {
51 patches: Patches,
52 fill_value: Scalar,
53 stats_set: ArrayStats,
54}
55
56#[derive(Clone, Debug)]
57pub struct SparseEncoding;
58
59impl SparseArray {
60 pub fn try_new(
61 indices: ArrayRef,
62 values: ArrayRef,
63 len: usize,
64 fill_value: Scalar,
65 ) -> VortexResult<Self> {
66 Self::try_new_with_offset(indices, values, len, 0, fill_value)
67 }
68
69 pub(crate) fn try_new_with_offset(
70 indices: ArrayRef,
71 values: ArrayRef,
72 len: usize,
73 indices_offset: usize,
74 fill_value: Scalar,
75 ) -> VortexResult<Self> {
76 if indices.len() != values.len() {
77 vortex_bail!(
78 "Mismatched indices {} and values {} length",
79 indices.len(),
80 values.len()
81 );
82 }
83
84 if !indices.is_empty() {
85 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
86
87 if last_index - indices_offset >= len {
88 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
89 }
90 }
91
92 let patches = Patches::new(len, indices_offset, indices, values);
93
94 Self::try_new_from_patches(patches, fill_value)
95 }
96
97 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
98 if fill_value.dtype() != patches.values().dtype() {
99 vortex_bail!(
100 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
101 fill_value,
102 patches.values().dtype(),
103 fill_value.dtype(),
104 );
105 }
106 Ok(Self {
107 patches,
108 fill_value,
109 stats_set: Default::default(),
110 })
111 }
112
113 #[inline]
114 pub fn patches(&self) -> &Patches {
115 &self.patches
116 }
117
118 #[inline]
119 pub fn resolved_patches(&self) -> VortexResult<Patches> {
120 let patches = self.patches();
121 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
122 let indices = sub_scalar(patches.indices(), indices_offset)?;
123 Ok(Patches::new(
124 patches.array_len(),
125 0,
126 indices,
127 patches.values().clone(),
128 ))
129 }
130
131 #[inline]
132 pub fn fill_scalar(&self) -> &Scalar {
133 &self.fill_value
134 }
135
136 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
140 if let Some(fill_value) = fill_value.as_ref() {
141 if array.dtype() != fill_value.dtype() {
142 vortex_bail!(
143 "Array and fill value types must match. got {} and {}",
144 array.dtype(),
145 fill_value.dtype()
146 )
147 }
148 }
149 let mask = array.validity_mask()?;
150
151 if mask.all_false() {
152 return Ok(
154 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
155 );
156 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
157 let non_null_values = filter(array, &mask)?;
159 let non_null_indices = match mask.indices() {
160 AllOr::All => {
161 unreachable!("Mask is mostly null")
163 }
164 AllOr::None => {
165 unreachable!("Mask is mostly null but not all null")
167 }
168 AllOr::Some(values) => {
169 let buffer: Buffer<u32> = values
170 .iter()
171 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
172 .collect();
173
174 buffer.into_array()
175 }
176 };
177
178 return Ok(SparseArray::try_new(
179 non_null_indices,
180 non_null_values,
181 array.len(),
182 Scalar::null(array.dtype().clone()),
183 )?
184 .into_array());
185 }
186
187 let fill = if let Some(fill) = fill_value {
188 fill
189 } else {
190 let (top_pvalue, _) = array
192 .to_primitive()?
193 .top_value()?
194 .vortex_expect("Non empty or all null array");
195
196 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
197 };
198
199 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
200 let non_top_mask = Mask::from_buffer(
201 fill_null(
202 &compare(array, &fill_array, Operator::NotEq)?,
203 &Scalar::bool(true, Nullability::NonNullable),
204 )?
205 .to_bool()?
206 .boolean_buffer()
207 .clone(),
208 );
209
210 let non_top_values = filter(array, &non_top_mask)?;
211
212 let indices: Buffer<u64> = match non_top_mask {
213 Mask::AllTrue(count) => {
214 (0u64..count as u64).collect()
216 }
217 Mask::AllFalse(_) => {
218 return Ok(fill_array);
220 }
221 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
222 };
223
224 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
225 .map(|a| a.into_array())
226 }
227}
228
229impl ArrayVTable<SparseVTable> for SparseVTable {
230 fn len(array: &SparseArray) -> usize {
231 array.patches.array_len()
232 }
233
234 fn dtype(array: &SparseArray) -> &DType {
235 array.fill_scalar().dtype()
236 }
237
238 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
239 array.stats_set.to_ref(array.as_ref())
240 }
241}
242
243impl ValidityVTable<SparseVTable> for SparseVTable {
244 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
245 Ok(match array.patches().get_patched(index)? {
246 None => array.fill_scalar().is_valid(),
247 Some(patch_value) => patch_value.is_valid(),
248 })
249 }
250
251 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
252 if array.fill_scalar().is_null() {
253 return Ok(array.patches().values().len() == array.len()
255 && array.patches().values().all_valid()?);
256 }
257
258 array.patches().values().all_valid()
259 }
260
261 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
262 if !array.fill_scalar().is_null() {
263 return Ok(array.patches().values().len() == array.len()
265 && array.patches().values().all_invalid()?);
266 }
267
268 array.patches().values().all_invalid()
269 }
270
271 #[allow(clippy::unnecessary_fallible_conversions)]
272 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
273 let fill_is_valid = array.fill_scalar().is_valid();
274 let values_validity = array.patches().values().validity_mask()?;
275 let len = array.len();
276
277 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
278 return Ok(Mask::AllTrue(len));
279 }
280 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
281 return Ok(Mask::AllFalse(len));
282 }
283
284 let mut is_valid_buffer = BooleanBufferBuilder::new(len);
286 is_valid_buffer.append_n(len, fill_is_valid);
287
288 let indices = array.patches().indices().to_primitive()?;
289 let index_offset = array.patches().offset();
290
291 match_each_integer_ptype!(indices.ptype(), |I| {
292 let indices = indices.as_slice::<I>();
293 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
294 });
295
296 Ok(Mask::from_buffer(is_valid_buffer.finish()))
297 }
298}
299
300fn patch_validity<I: NativePType>(
301 is_valid_buffer: &mut BooleanBufferBuilder,
302 indices: &[I],
303 index_offset: usize,
304 values_validity: Mask,
305) {
306 let indices = indices.iter().map(|index| {
307 let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
308 index - index_offset
309 });
310 match values_validity {
311 Mask::AllTrue(_) => {
312 for index in indices {
313 is_valid_buffer.set_bit(index, true);
314 }
315 }
316 Mask::AllFalse(_) => {
317 for index in indices {
318 is_valid_buffer.set_bit(index, false);
319 }
320 }
321 Mask::Values(mask_values) => {
322 let is_valid = mask_values.boolean_buffer().iter();
323 for (index, is_valid) in indices.zip_eq(is_valid) {
324 is_valid_buffer.set_bit(index, is_valid);
325 }
326 }
327 }
328}
329
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null nullable-`i32` scalar, used as a null fill value.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with values 100, 200 and 300 at positions 2, 5 and 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // Match the values' dtype (and nullability) to the fill.
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        let error = array.scalar_at(10).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap())
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
        let error = sliced.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );
        let error = sliced_once.scalar_at(7).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
        let error2 = sliced_twice.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        // Index 100 is out of bounds for a length-100 array. Assert on the returned error
        // rather than `#[should_panic]`, which would also pass on an unrelated panic.
        assert!(SparseArray::try_new(indices, values, 100, 0_u32.into()).is_err());
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}