1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10 Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, IntoArray,
11 RkyvMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_mask::{AllOr, Mask};
17use vortex_scalar::Scalar;
18
19use crate::serde::SparseMetadata;
20
21mod canonical;
22mod compute;
23mod serde;
24mod variants;
25
/// An array that stores explicit values only at a set of patched positions and reports a
/// single `fill_value` at every other position.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Patched positions and their values, together with the logical array length
    /// (`array_len`) and the offset that must be subtracted from stored indices.
    patches: Patches,
    /// Value of every position not covered by `patches`; its dtype is the array's dtype.
    fill_value: Scalar,
    /// Statistics for this array; starts out empty (`Default`) when the array is built.
    stats_set: ArrayStats,
}

// Macro from `vortex_array`; presumably derives the conversions from a generic array
// reference into `SparseArray` — see its definition to confirm the exact impls.
try_from_array_ref!(SparseArray);
34
/// Encoding marker type for [`SparseArray`].
pub struct SparseEncoding;
impl Encoding for SparseEncoding {
    type Array = SparseArray;
    // Serialized metadata is `SparseMetadata` (from this crate's `serde` module) wrapped in
    // `RkyvMetadata` — presumably rkyv-serialized; confirm against `RkyvMetadata`'s docs.
    type Metadata = RkyvMetadata<SparseMetadata>;
}
40
41impl SparseArray {
42 pub fn try_new(
43 indices: ArrayRef,
44 values: ArrayRef,
45 len: usize,
46 fill_value: Scalar,
47 ) -> VortexResult<Self> {
48 Self::try_new_with_offset(indices, values, len, 0, fill_value)
49 }
50
51 pub(crate) fn try_new_with_offset(
52 indices: ArrayRef,
53 values: ArrayRef,
54 len: usize,
55 indices_offset: usize,
56 fill_value: Scalar,
57 ) -> VortexResult<Self> {
58 if indices.len() != values.len() {
59 vortex_bail!(
60 "Mismatched indices {} and values {} length",
61 indices.len(),
62 values.len()
63 );
64 }
65
66 if !indices.is_empty() {
67 let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
68
69 if last_index - indices_offset >= len {
70 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
71 }
72 }
73
74 let patches = Patches::new(len, indices_offset, indices, values);
75
76 Self::try_new_from_patches(patches, fill_value)
77 }
78
79 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
80 if fill_value.dtype() != patches.values().dtype() {
81 vortex_bail!(
82 "fill value, {:?}, should be instance of values dtype, {}",
83 fill_value,
84 patches.values().dtype(),
85 );
86 }
87 Ok(Self {
88 patches,
89 fill_value,
90 stats_set: Default::default(),
91 })
92 }
93
94 #[inline]
95 pub fn patches(&self) -> &Patches {
96 &self.patches
97 }
98
99 #[inline]
100 pub fn resolved_patches(&self) -> VortexResult<Patches> {
101 let (len, offset, indices, values) = self.patches().clone().into_parts();
102 let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
103 let indices = sub_scalar(&indices, indices_offset)?;
104 Ok(Patches::new(len, 0, indices, values))
105 }
106
107 #[inline]
108 pub fn fill_scalar(&self) -> &Scalar {
109 &self.fill_value
110 }
111
112 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
116 if let Some(fill_value) = fill_value.as_ref() {
117 if array.dtype() != fill_value.dtype() {
118 vortex_bail!(
119 "Array and fill value types must match. got {} and {}",
120 array.dtype(),
121 fill_value.dtype()
122 )
123 }
124 }
125 let mask = array.validity_mask()?;
126
127 if mask.all_false() {
128 return Ok(
130 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
131 );
132 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
133 let non_null_values = filter(array, &mask)?;
135 let non_null_indices = match mask.indices() {
136 AllOr::All => {
137 unreachable!("Mask is mostly null")
139 }
140 AllOr::None => {
141 unreachable!("Mask is mostly null but not all null")
143 }
144 AllOr::Some(values) => {
145 let buffer: Buffer<u32> = values
146 .iter()
147 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
148 .collect();
149
150 buffer.into_array()
151 }
152 };
153
154 return Ok(SparseArray::try_new(
155 non_null_indices,
156 non_null_values,
157 array.len(),
158 Scalar::null(array.dtype().clone()),
159 )?
160 .into_array());
161 }
162
163 let fill = if let Some(fill) = fill_value {
164 fill
165 } else {
166 let (top_pvalue, _) = array
168 .to_primitive()?
169 .top_value()?
170 .vortex_expect("Non empty or all null array");
171
172 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
173 };
174
175 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
176 let non_top_mask = Mask::from_buffer(
177 fill_null(
178 &compare(array, &fill_array, Operator::NotEq)?,
179 Scalar::bool(true, Nullability::NonNullable),
180 )?
181 .to_bool()?
182 .boolean_buffer()
183 .clone(),
184 );
185
186 let non_top_values = filter(array, &non_top_mask)?;
187
188 let indices: Buffer<u64> = match non_top_mask {
189 Mask::AllTrue(count) => {
190 (0u64..count as u64).collect()
192 }
193 Mask::AllFalse(_) => {
194 return Ok(fill_array);
196 }
197 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
198 };
199
200 SparseArray::try_new(
201 indices.into_array(),
202 non_top_values.into_array(),
203 array.len(),
204 fill,
205 )
206 .map(|a| a.into_array())
207 }
208}
209
210impl ArrayImpl for SparseArray {
211 type Encoding = SparseEncoding;
212
213 fn _len(&self) -> usize {
214 self.patches.array_len()
215 }
216
217 fn _dtype(&self) -> &DType {
218 self.fill_value.dtype()
219 }
220
221 fn _vtable(&self) -> VTableRef {
222 VTableRef::new_ref(&SparseEncoding)
223 }
224
225 fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
226 let patch_indices = children[0].clone();
227 let patch_values = children[1].clone();
228
229 let patches = Patches::new(
230 self.patches().array_len(),
231 self.patches().offset(),
232 patch_indices,
233 patch_values,
234 );
235
236 Self::try_new_from_patches(patches, self.fill_value.clone())
237 }
238}
239
impl ArrayStatisticsImpl for SparseArray {
    /// Returns a statistics handle backed by this array's `stats_set`.
    fn _stats_ref(&self) -> StatsSetRef<'_> {
        self.stats_set.to_ref(self)
    }
}
245
246impl ArrayValidityImpl for SparseArray {
247 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
248 Ok(match self.patches().get_patched(index)? {
249 None => self.fill_scalar().is_valid(),
250 Some(patch_value) => patch_value.is_valid(),
251 })
252 }
253
254 fn _all_valid(&self) -> VortexResult<bool> {
255 if self.fill_scalar().is_null() {
256 return Ok(self.patches().values().len() == self.len()
258 && self.patches().values().all_valid()?);
259 }
260
261 self.patches().values().all_valid()
262 }
263
264 fn _all_invalid(&self) -> VortexResult<bool> {
265 if !self.fill_scalar().is_null() {
266 return Ok(self.patches().values().len() == self.len()
268 && self.patches().values().all_invalid()?);
269 }
270
271 self.patches().values().all_invalid()
272 }
273
274 fn _validity_mask(&self) -> VortexResult<Mask> {
275 let indices = self.patches().indices().to_primitive()?;
276
277 if self.fill_scalar().is_null() {
278 let mut buffer = BooleanBufferBuilder::new(self.len());
280 buffer.append_n(self.len(), false);
282
283 match_each_integer_ptype!(indices.ptype(), |$I| {
284 indices.as_slice::<$I>().into_iter().for_each(|&index| {
285 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), true);
286 });
287 });
288
289 return Ok(Mask::from_buffer(buffer.finish()));
290 }
291
292 let mut buffer = BooleanBufferBuilder::new(self.len());
295 buffer.append_n(self.len(), true);
296
297 let values_validity = self.patches().values().validity_mask()?;
298 match_each_integer_ptype!(indices.ptype(), |$I| {
299 indices.as_slice::<$I>()
300 .into_iter()
301 .enumerate()
302 .for_each(|(patch_idx, &index)| {
303 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), values_validity.value(patch_idx));
304 })
305 });
306
307 Ok(Mask::from_buffer(buffer.finish()))
308 }
309}
310
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{slice, try_cast};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null scalar of dtype nullable `i32`, used as a fill value.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// Non-null `42i32` scalar, used as a fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Builds a length-10 sparse array holding 100, 200 and 300 at positions 2, 5 and 8,
    /// with `fill_value` everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // Cast the patch values to the fill's dtype so the two dtypes agree.
        let mut values = buffer![100i32, 200, 300].into_array();
        values = try_cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions return the fill, patched positions return their value, and
    /// index == len is reported as `OutOfBounds`.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(scalar_at(&array, 0).unwrap(), nullable_fill());
        assert_eq!(scalar_at(&array, 2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(scalar_at(&array, 5).unwrap(), Scalar::from(Some(200_i32)));

        let error = scalar_at(&array, 10).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    /// Patch indices and values given as constant arrays (a single patch at position 10).
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&scalar_at(&arr, 10).unwrap())
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(scalar_at(&arr, 0).unwrap().is_null());
        assert!(scalar_at(&arr, 99).unwrap().is_null());
    }

    /// After slicing [2, 7), position 0 maps to the original patch at 2, and the bounds are
    /// relative to the slice.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();
        assert_eq!(
            usize::try_from(&scalar_at(&sliced, 0).unwrap()).unwrap(),
            100
        );
        let error = scalar_at(&sliced, 5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// Null fill on a sliced array: only the patches at original positions 2 and 5
    /// (slice positions 0 and 3) are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    /// Non-null fill with all-null patches on a sliced array: the patched positions are
    /// invalid and everything else is valid.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = slice(
            &SparseArray::try_new(
                buffer![2u64, 5, 8].into_array(),
                ConstantArray::new(
                    Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                    3,
                )
                .into_array(),
                10,
                Scalar::primitive(1.0f32, Nullability::Nullable),
            )
            .unwrap()
            .into_array(),
            2,
            7,
        )
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    /// Slicing a slice composes offsets correctly for lookups and bounds errors.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();
        assert_eq!(
            usize::try_from(&scalar_at(&sliced_once, 1).unwrap()).unwrap(),
            100
        );
        let error = scalar_at(&sliced_once, 7).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();
        assert_eq!(
            usize::try_from(&scalar_at(&sliced_twice, 3).unwrap()).unwrap(),
            200
        );
        let error2 = scalar_at(&sliced_twice, 5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// Null fill on an unsliced array: exactly positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// A non-null fill with valid patches gives an all-valid mask.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    /// The last index (100) equals the requested length (100), so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    /// The last index (100) is below the requested length (101), so construction succeeds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    /// Round-trips a nullable primitive array through `encode`: both the validity and the
    /// canonical values must match the input.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}