1use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
// Generates the vtable boilerplate for the `Sparse` encoding; presumably this is what defines
// the `SparseVTable` type implemented below — confirm against the `vortex_array::vtable!` macro.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28 type Array = SparseArray;
29 type Encoding = SparseEncoding;
30
31 type ArrayVTable = Self;
32 type CanonicalVTable = Self;
33 type OperationsVTable = Self;
34 type ValidityVTable = Self;
35 type VisitorVTable = Self;
36 type ComputeVTable = NotSupported;
37 type EncodeVTable = Self;
38 type SerdeVTable = Self;
39 type PipelineVTable = NotSupported;
40
41 fn id(_encoding: &Self::Encoding) -> EncodingId {
42 EncodingId::new_ref("vortex.sparse")
43 }
44
45 fn encoding(_array: &Self::Array) -> EncodingRef {
46 EncodingRef::new_ref(SparseEncoding.as_ref())
47 }
48}
49
/// An array where only the positions recorded in `patches` store explicit values; every
/// other position takes `fill_value`.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicitly stored (index, value) pairs; positions without a patch read as `fill_value`.
    patches: Patches,
    /// Value of every unpatched position. Its dtype is also reported as the array's dtype.
    fill_value: Scalar,
    /// Statistics storage for this array, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}

/// Encoding marker for [`SparseArray`] (encoding id `vortex.sparse`).
#[derive(Clone, Debug)]
pub struct SparseEncoding;
59
60impl SparseArray {
61 pub fn try_new(
62 indices: ArrayRef,
63 values: ArrayRef,
64 len: usize,
65 fill_value: Scalar,
66 ) -> VortexResult<Self> {
67 vortex_ensure!(
68 indices.len() == values.len(),
69 "Mismatched indices {} and values {} length",
70 indices.len(),
71 values.len()
72 );
73
74 vortex_ensure!(
75 indices.statistics().compute_is_strict_sorted() == Some(true),
76 "SparseArray: indices must be strict-sorted"
77 );
78
79 if !indices.is_empty() {
81 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
82
83 vortex_ensure!(
84 last_index < len,
85 "Array length was {len} but the last index is {last_index}"
86 );
87 }
88
89 let patches = Patches::new(len, 0, indices, values);
90
91 Ok(Self {
92 patches,
93 fill_value,
94 stats_set: Default::default(),
95 })
96 }
97
98 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
100 vortex_ensure!(
101 fill_value.dtype() == patches.values().dtype(),
102 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
103 fill_value,
104 patches.values().dtype(),
105 fill_value.dtype(),
106 );
107
108 Ok(Self {
109 patches,
110 fill_value,
111 stats_set: Default::default(),
112 })
113 }
114
115 pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
116 Self {
117 patches,
118 fill_value,
119 stats_set: Default::default(),
120 }
121 }
122
123 #[inline]
124 pub fn patches(&self) -> &Patches {
125 &self.patches
126 }
127
128 #[inline]
129 pub fn resolved_patches(&self) -> VortexResult<Patches> {
130 let patches = self.patches();
131 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
132 let indices = sub_scalar(patches.indices(), indices_offset)?;
133 Ok(Patches::new(
134 patches.array_len(),
135 0,
136 indices,
137 patches.values().clone(),
138 ))
139 }
140
141 #[inline]
142 pub fn fill_scalar(&self) -> &Scalar {
143 &self.fill_value
144 }
145
146 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
150 if let Some(fill_value) = fill_value.as_ref()
151 && array.dtype() != fill_value.dtype()
152 {
153 vortex_bail!(
154 "Array and fill value types must match. got {} and {}",
155 array.dtype(),
156 fill_value.dtype()
157 )
158 }
159 let mask = array.validity_mask()?;
160
161 if mask.all_false() {
162 return Ok(
164 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
165 );
166 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
167 let non_null_values = filter(array, &mask)?;
169 let non_null_indices = match mask.indices() {
170 AllOr::All => {
171 unreachable!("Mask is mostly null")
173 }
174 AllOr::None => {
175 unreachable!("Mask is mostly null but not all null")
177 }
178 AllOr::Some(values) => {
179 let buffer: Buffer<u32> = values
180 .iter()
181 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
182 .collect();
183
184 buffer.into_array()
185 }
186 };
187
188 return Ok(SparseArray::try_new(
189 non_null_indices,
190 non_null_values,
191 array.len(),
192 Scalar::null(array.dtype().clone()),
193 )?
194 .into_array());
195 }
196
197 let fill = if let Some(fill) = fill_value {
198 fill
199 } else {
200 let (top_pvalue, _) = array
202 .to_primitive()?
203 .top_value()?
204 .vortex_expect("Non empty or all null array");
205
206 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
207 };
208
209 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
210 let non_top_mask = Mask::from_buffer(
211 fill_null(
212 &compare(array, &fill_array, Operator::NotEq)?,
213 &Scalar::bool(true, Nullability::NonNullable),
214 )?
215 .to_bool()?
216 .boolean_buffer()
217 .clone(),
218 );
219
220 let non_top_values = filter(array, &non_top_mask)?;
221
222 let indices: Buffer<u64> = match non_top_mask {
223 Mask::AllTrue(count) => {
224 (0u64..count as u64).collect()
226 }
227 Mask::AllFalse(_) => {
228 return Ok(fill_array);
230 }
231 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
232 };
233
234 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
235 .map(|a| a.into_array())
236 }
237}
238
239impl ArrayVTable<SparseVTable> for SparseVTable {
240 fn len(array: &SparseArray) -> usize {
241 array.patches.array_len()
242 }
243
244 fn dtype(array: &SparseArray) -> &DType {
245 array.fill_scalar().dtype()
246 }
247
248 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
249 array.stats_set.to_ref(array.as_ref())
250 }
251}
252
253impl ValidityVTable<SparseVTable> for SparseVTable {
254 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
255 Ok(match array.patches().get_patched(index) {
256 None => array.fill_scalar().is_valid(),
257 Some(patch_value) => patch_value.is_valid(),
258 })
259 }
260
261 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
262 if array.fill_scalar().is_null() {
263 return Ok(array.patches().values().len() == array.len()
265 && array.patches().values().all_valid()?);
266 }
267
268 array.patches().values().all_valid()
269 }
270
271 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
272 if !array.fill_scalar().is_null() {
273 return Ok(array.patches().values().len() == array.len()
275 && array.patches().values().all_invalid()?);
276 }
277
278 array.patches().values().all_invalid()
279 }
280
281 #[allow(clippy::unnecessary_fallible_conversions)]
282 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
283 let fill_is_valid = array.fill_scalar().is_valid();
284 let values_validity = array.patches().values().validity_mask()?;
285 let len = array.len();
286
287 if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
288 return Ok(Mask::AllTrue(len));
289 }
290 if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
291 return Ok(Mask::AllFalse(len));
292 }
293
294 let mut is_valid_buffer = BooleanBufferBuilder::new(len);
296 is_valid_buffer.append_n(len, fill_is_valid);
297
298 let indices = array.patches().indices().to_primitive()?;
299 let index_offset = array.patches().offset();
300
301 match_each_integer_ptype!(indices.ptype(), |I| {
302 let indices = indices.as_slice::<I>();
303 patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
304 });
305
306 Ok(Mask::from_buffer(is_valid_buffer.finish()))
307 }
308}
309
310fn patch_validity<I: NativePType>(
311 is_valid_buffer: &mut BooleanBufferBuilder,
312 indices: &[I],
313 index_offset: usize,
314 values_validity: Mask,
315) {
316 let indices = indices.iter().map(|index| {
317 let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
318 index - index_offset
319 });
320 match values_validity {
321 Mask::AllTrue(_) => {
322 for index in indices {
323 is_valid_buffer.set_bit(index, true);
324 }
325 }
326 Mask::AllFalse(_) => {
327 for index in indices {
328 is_valid_buffer.set_bit(index, false);
329 }
330 }
331 Mask::Values(mask_values) => {
332 let is_valid = mask_values.boolean_buffer().iter();
333 for (index, is_valid) in indices.zip_eq(is_valid) {
334 is_valid_buffer.set_bit(index, is_valid);
335 }
336 }
337 }
338}
339
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null fill value of dtype nullable `i32`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill value (`42i32`).
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with 100, 200, 300 at indices 2, 5, 8. The values are cast to
    /// the fill value's dtype so the two agree.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let _ = array.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7);

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1, 6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        // Index 100 is out of bounds for a length-100 array. Assert on the returned error
        // instead of a bare `#[should_panic]`, so an unrelated panic cannot make this pass.
        assert!(SparseArray::try_new(indices, values, 100, 0_u32.into()).is_err());
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}