1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
8use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15mod canonical;
16mod compute;
17mod ops;
18mod serde;
19
// NOTE(review): presumably this macro generates the `SparseVTable` type that the `impl` blocks
// in this file are written against — confirm against the `vtable!` macro definition.
vtable!(Sparse);
21
22impl VTable for SparseVTable {
23 type Array = SparseArray;
24 type Encoding = SparseEncoding;
25
26 type ArrayVTable = Self;
27 type CanonicalVTable = Self;
28 type OperationsVTable = Self;
29 type ValidityVTable = Self;
30 type VisitorVTable = Self;
31 type ComputeVTable = NotSupported;
32 type EncodeVTable = Self;
33 type SerdeVTable = Self;
34
35 fn id(_encoding: &Self::Encoding) -> EncodingId {
36 EncodingId::new_ref("vortex.sparse")
37 }
38
39 fn encoding(_array: &Self::Array) -> EncodingRef {
40 EncodingRef::new_ref(SparseEncoding.as_ref())
41 }
42}
43
/// An array whose positions all hold a single fill value, except for an explicitly stored set
/// of positions ("patches") that carry their own values.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// The explicitly stored positions and their values. Stored indices are shifted by
    /// `patches.offset()`, and `patches.array_len()` is the logical length of this array.
    patches: Patches,
    /// The value of every position not covered by `patches`. Its dtype is the array's dtype, and
    /// the constructor requires it to equal the patch values' dtype.
    fill_value: Scalar,
    /// Statistics storage for this array (created with `Default::default()` by the constructor).
    stats_set: ArrayStats,
}
50
/// Data-less marker type for the sparse encoding (identified as `"vortex.sparse"`).
#[derive(Clone, Debug)]
pub struct SparseEncoding;
53
54impl SparseArray {
55 pub fn try_new(
56 indices: ArrayRef,
57 values: ArrayRef,
58 len: usize,
59 fill_value: Scalar,
60 ) -> VortexResult<Self> {
61 Self::try_new_with_offset(indices, values, len, 0, fill_value)
62 }
63
64 pub(crate) fn try_new_with_offset(
65 indices: ArrayRef,
66 values: ArrayRef,
67 len: usize,
68 indices_offset: usize,
69 fill_value: Scalar,
70 ) -> VortexResult<Self> {
71 if indices.len() != values.len() {
72 vortex_bail!(
73 "Mismatched indices {} and values {} length",
74 indices.len(),
75 values.len()
76 );
77 }
78
79 if !indices.is_empty() {
80 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
81
82 if last_index - indices_offset >= len {
83 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
84 }
85 }
86
87 let patches = Patches::new(len, indices_offset, indices, values);
88
89 Self::try_new_from_patches(patches, fill_value)
90 }
91
92 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
93 if fill_value.dtype() != patches.values().dtype() {
94 vortex_bail!(
95 "fill value, {:?}, should be instance of values dtype, {} but was {}.",
96 fill_value,
97 patches.values().dtype(),
98 fill_value.dtype(),
99 );
100 }
101 Ok(Self {
102 patches,
103 fill_value,
104 stats_set: Default::default(),
105 })
106 }
107
108 #[inline]
109 pub fn patches(&self) -> &Patches {
110 &self.patches
111 }
112
113 #[inline]
114 pub fn resolved_patches(&self) -> VortexResult<Patches> {
115 let patches = self.patches();
116 let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
117 let indices = sub_scalar(patches.indices(), indices_offset)?;
118 Ok(Patches::new(
119 patches.array_len(),
120 0,
121 indices,
122 patches.values().clone(),
123 ))
124 }
125
126 #[inline]
127 pub fn fill_scalar(&self) -> &Scalar {
128 &self.fill_value
129 }
130
131 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
135 if let Some(fill_value) = fill_value.as_ref() {
136 if array.dtype() != fill_value.dtype() {
137 vortex_bail!(
138 "Array and fill value types must match. got {} and {}",
139 array.dtype(),
140 fill_value.dtype()
141 )
142 }
143 }
144 let mask = array.validity_mask()?;
145
146 if mask.all_false() {
147 return Ok(
149 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
150 );
151 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
152 let non_null_values = filter(array, &mask)?;
154 let non_null_indices = match mask.indices() {
155 AllOr::All => {
156 unreachable!("Mask is mostly null")
158 }
159 AllOr::None => {
160 unreachable!("Mask is mostly null but not all null")
162 }
163 AllOr::Some(values) => {
164 let buffer: Buffer<u32> = values
165 .iter()
166 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
167 .collect();
168
169 buffer.into_array()
170 }
171 };
172
173 return Ok(SparseArray::try_new(
174 non_null_indices,
175 non_null_values,
176 array.len(),
177 Scalar::null(array.dtype().clone()),
178 )?
179 .into_array());
180 }
181
182 let fill = if let Some(fill) = fill_value {
183 fill
184 } else {
185 let (top_pvalue, _) = array
187 .to_primitive()?
188 .top_value()?
189 .vortex_expect("Non empty or all null array");
190
191 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
192 };
193
194 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
195 let non_top_mask = Mask::from_buffer(
196 fill_null(
197 &compare(array, &fill_array, Operator::NotEq)?,
198 &Scalar::bool(true, Nullability::NonNullable),
199 )?
200 .to_bool()?
201 .boolean_buffer()
202 .clone(),
203 );
204
205 let non_top_values = filter(array, &non_top_mask)?;
206
207 let indices: Buffer<u64> = match non_top_mask {
208 Mask::AllTrue(count) => {
209 (0u64..count as u64).collect()
211 }
212 Mask::AllFalse(_) => {
213 return Ok(fill_array);
215 }
216 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
217 };
218
219 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
220 .map(|a| a.into_array())
221 }
222}
223
224impl ArrayVTable<SparseVTable> for SparseVTable {
225 fn len(array: &SparseArray) -> usize {
226 array.patches.array_len()
227 }
228
229 fn dtype(array: &SparseArray) -> &DType {
230 array.fill_scalar().dtype()
231 }
232
233 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
234 array.stats_set.to_ref(array.as_ref())
235 }
236}
237
238impl ValidityVTable<SparseVTable> for SparseVTable {
239 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
240 Ok(match array.patches().get_patched(index)? {
241 None => array.fill_scalar().is_valid(),
242 Some(patch_value) => patch_value.is_valid(),
243 })
244 }
245
246 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
247 if array.fill_scalar().is_null() {
248 return Ok(array.patches().values().len() == array.len()
250 && array.patches().values().all_valid()?);
251 }
252
253 array.patches().values().all_valid()
254 }
255
256 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
257 if !array.fill_scalar().is_null() {
258 return Ok(array.patches().values().len() == array.len()
260 && array.patches().values().all_invalid()?);
261 }
262
263 array.patches().values().all_invalid()
264 }
265
266 #[allow(clippy::unnecessary_fallible_conversions)]
267 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
268 let indices = array.patches().indices().to_primitive()?;
269
270 if array.fill_scalar().is_null() {
271 let mut buffer = BooleanBufferBuilder::new(array.len());
273 buffer.append_n(array.len(), false);
275
276 match_each_integer_ptype!(indices.ptype(), |I| {
277 indices.as_slice::<I>().iter().for_each(|&index| {
278 buffer.set_bit(
279 usize::try_from(index).vortex_expect("Failed to cast to usize")
280 - array.patches().offset(),
281 true,
282 );
283 });
284 });
285
286 return Ok(Mask::from_buffer(buffer.finish()));
287 }
288
289 let mut buffer = BooleanBufferBuilder::new(array.len());
292 buffer.append_n(array.len(), true);
293
294 let values_validity = array.patches().values().validity_mask()?;
295 match_each_integer_ptype!(indices.ptype(), |I| {
296 indices
297 .as_slice::<I>()
298 .iter()
299 .enumerate()
300 .for_each(|(patch_idx, &index)| {
301 buffer.set_bit(
302 usize::try_from(index).vortex_expect("Failed to cast to usize")
303 - array.patches().offset(),
304 values_validity.value(patch_idx),
305 );
306 })
307 });
308
309 Ok(Mask::from_buffer(buffer.finish()))
310 }
311}
312
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null `i32` fill value (nullable dtype).
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value of 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8 and `fill_value`
    /// everywhere else. The values are cast to the fill value's dtype so the dtypes match.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions return the fill, patched positions return their value, and reading
    /// past the end yields an `OutOfBounds` error.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        let error = array.scalar_at(10).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    /// Single patch (index 10) built from constant arrays, with a null fill, in a length-100 array.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap())
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    /// After slicing 2..7, position 0 maps to original position 2 and the length becomes 5.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
        let error = sliced.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// Null fill: only the patched positions (original 2 and 5 -> sliced 0 and 3) are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    /// Non-null fill with all-null patches: only the patched positions are invalid.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    /// Slicing a slice composes the offsets: 1..8, then 1..6, maps position 3 to original 5.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );
        let error = sliced_once.scalar_at(7).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
        let error2 = sliced_twice.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// Full (unsliced) validity mask with a null fill: true exactly at 2, 5 and 8.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// A non-null fill with valid patches gives an all-true mask.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    /// A last index equal to `len` is rejected.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    /// A last index of `len - 1` is accepted.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    /// Round-trips a nullable array through `SparseArray::encode`.
    ///
    /// Five of the twelve elements are null, which is below the 90% threshold, so this takes
    /// the top-value path. Both the validity mask and the canonical values are checked.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}