1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10 Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, IntoArray,
11 ProstMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_mask::{AllOr, Mask};
17use vortex_scalar::Scalar;
18
19use crate::serde::SparseMetadata;
20
21mod canonical;
22mod compute;
23mod serde;
24mod variants;
25
/// An array that stores a set of patched positions explicitly; every position without a patch
/// takes the value of a single fill scalar.
#[derive(Clone, Debug)]
pub struct SparseArray {
    // Patch indices and values, together with the logical array length and the offset that is
    // subtracted from stored indices to obtain positions in this array.
    patches: Patches,
    // Scalar used for every position without a patch; its dtype is reported as the array dtype.
    fill_value: Scalar,
    // Statistics storage handed out through `ArrayStatisticsImpl::_stats_ref`.
    stats_set: ArrayStats,
}
32
// NOTE(review): presumably generates the conversions from `ArrayRef` to `SparseArray`; the macro
// is defined in `vortex_array`, so confirm its exact output there.
try_from_array_ref!(SparseArray);
34
/// Encoding marker type for [`SparseArray`].
#[derive(Debug)]
pub struct SparseEncoding;
// Associates the encoding with its array type and with the metadata type it uses.
impl Encoding for SparseEncoding {
    type Array = SparseArray;
    // `SparseMetadata` is defined in this crate's `serde` module.
    type Metadata = ProstMetadata<SparseMetadata>;
}
41
42impl SparseArray {
43 pub fn try_new(
44 indices: ArrayRef,
45 values: ArrayRef,
46 len: usize,
47 fill_value: Scalar,
48 ) -> VortexResult<Self> {
49 Self::try_new_with_offset(indices, values, len, 0, fill_value)
50 }
51
52 pub(crate) fn try_new_with_offset(
53 indices: ArrayRef,
54 values: ArrayRef,
55 len: usize,
56 indices_offset: usize,
57 fill_value: Scalar,
58 ) -> VortexResult<Self> {
59 if indices.len() != values.len() {
60 vortex_bail!(
61 "Mismatched indices {} and values {} length",
62 indices.len(),
63 values.len()
64 );
65 }
66
67 if !indices.is_empty() {
68 let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
69
70 if last_index - indices_offset >= len {
71 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
72 }
73 }
74
75 let patches = Patches::new(len, indices_offset, indices, values);
76
77 Self::try_new_from_patches(patches, fill_value)
78 }
79
80 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
81 if fill_value.dtype() != patches.values().dtype() {
82 vortex_bail!(
83 "fill value, {:?}, should be instance of values dtype, {}",
84 fill_value,
85 patches.values().dtype(),
86 );
87 }
88 Ok(Self {
89 patches,
90 fill_value,
91 stats_set: Default::default(),
92 })
93 }
94
95 #[inline]
96 pub fn patches(&self) -> &Patches {
97 &self.patches
98 }
99
100 #[inline]
101 pub fn resolved_patches(&self) -> VortexResult<Patches> {
102 let (len, offset, indices, values) = self.patches().clone().into_parts();
103 let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
104 let indices = sub_scalar(&indices, indices_offset)?;
105 Ok(Patches::new(len, 0, indices, values))
106 }
107
108 #[inline]
109 pub fn fill_scalar(&self) -> &Scalar {
110 &self.fill_value
111 }
112
113 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
117 if let Some(fill_value) = fill_value.as_ref() {
118 if array.dtype() != fill_value.dtype() {
119 vortex_bail!(
120 "Array and fill value types must match. got {} and {}",
121 array.dtype(),
122 fill_value.dtype()
123 )
124 }
125 }
126 let mask = array.validity_mask()?;
127
128 if mask.all_false() {
129 return Ok(
131 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
132 );
133 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
134 let non_null_values = filter(array, &mask)?;
136 let non_null_indices = match mask.indices() {
137 AllOr::All => {
138 unreachable!("Mask is mostly null")
140 }
141 AllOr::None => {
142 unreachable!("Mask is mostly null but not all null")
144 }
145 AllOr::Some(values) => {
146 let buffer: Buffer<u32> = values
147 .iter()
148 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
149 .collect();
150
151 buffer.into_array()
152 }
153 };
154
155 return Ok(SparseArray::try_new(
156 non_null_indices,
157 non_null_values,
158 array.len(),
159 Scalar::null(array.dtype().clone()),
160 )?
161 .into_array());
162 }
163
164 let fill = if let Some(fill) = fill_value {
165 fill
166 } else {
167 let (top_pvalue, _) = array
169 .to_primitive()?
170 .top_value()?
171 .vortex_expect("Non empty or all null array");
172
173 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
174 };
175
176 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
177 let non_top_mask = Mask::from_buffer(
178 fill_null(
179 &compare(array, &fill_array, Operator::NotEq)?,
180 Scalar::bool(true, Nullability::NonNullable),
181 )?
182 .to_bool()?
183 .boolean_buffer()
184 .clone(),
185 );
186
187 let non_top_values = filter(array, &non_top_mask)?;
188
189 let indices: Buffer<u64> = match non_top_mask {
190 Mask::AllTrue(count) => {
191 (0u64..count as u64).collect()
193 }
194 Mask::AllFalse(_) => {
195 return Ok(fill_array);
197 }
198 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
199 };
200
201 SparseArray::try_new(
202 indices.into_array(),
203 non_top_values.into_array(),
204 array.len(),
205 fill,
206 )
207 .map(|a| a.into_array())
208 }
209}
210
211impl ArrayImpl for SparseArray {
212 type Encoding = SparseEncoding;
213
214 fn _len(&self) -> usize {
215 self.patches.array_len()
216 }
217
218 fn _dtype(&self) -> &DType {
219 self.fill_value.dtype()
220 }
221
222 fn _vtable(&self) -> VTableRef {
223 VTableRef::new_ref(&SparseEncoding)
224 }
225
226 fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
227 let patch_indices = children[0].clone();
228 let patch_values = children[1].clone();
229
230 let patches = Patches::new(
231 self.patches().array_len(),
232 self.patches().offset(),
233 patch_indices,
234 patch_values,
235 );
236
237 Self::try_new_from_patches(patches, self.fill_value.clone())
238 }
239}
240
241impl ArrayStatisticsImpl for SparseArray {
242 fn _stats_ref(&self) -> StatsSetRef<'_> {
243 self.stats_set.to_ref(self)
244 }
245}
246
247impl ArrayValidityImpl for SparseArray {
248 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
249 Ok(match self.patches().get_patched(index)? {
250 None => self.fill_scalar().is_valid(),
251 Some(patch_value) => patch_value.is_valid(),
252 })
253 }
254
255 fn _all_valid(&self) -> VortexResult<bool> {
256 if self.fill_scalar().is_null() {
257 return Ok(self.patches().values().len() == self.len()
259 && self.patches().values().all_valid()?);
260 }
261
262 self.patches().values().all_valid()
263 }
264
265 fn _all_invalid(&self) -> VortexResult<bool> {
266 if !self.fill_scalar().is_null() {
267 return Ok(self.patches().values().len() == self.len()
269 && self.patches().values().all_invalid()?);
270 }
271
272 self.patches().values().all_invalid()
273 }
274
275 fn _validity_mask(&self) -> VortexResult<Mask> {
276 let indices = self.patches().indices().to_primitive()?;
277
278 if self.fill_scalar().is_null() {
279 let mut buffer = BooleanBufferBuilder::new(self.len());
281 buffer.append_n(self.len(), false);
283
284 match_each_integer_ptype!(indices.ptype(), |$I| {
285 indices.as_slice::<$I>().into_iter().for_each(|&index| {
286 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), true);
287 });
288 });
289
290 return Ok(Mask::from_buffer(buffer.finish()));
291 }
292
293 let mut buffer = BooleanBufferBuilder::new(self.len());
296 buffer.append_n(self.len(), true);
297
298 let values_validity = self.patches().values().validity_mask()?;
299 match_each_integer_ptype!(indices.ptype(), |$I| {
300 indices.as_slice::<$I>()
301 .into_iter()
302 .enumerate()
303 .for_each(|(patch_idx, &index)| {
304 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), values_validity.value(patch_idx));
305 })
306 });
307
308 Ok(Mask::from_buffer(buffer.finish()))
309 }
310}
311
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{cast, slice};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null `i32` scalar used as the fill of the nullable fixture.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// Non-nullable `i32` scalar used as the fill of the non-nullable fixture.
    fn non_nullable_fill() -> Scalar {
        42_i32.into()
    }

    /// Ten-element sparse array with patches 100, 200 and 300 at positions 2, 5 and 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        // The patch values must share the dtype (including nullability) of the fill value.
        let values = cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let sparse = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (position, expected) in expectations {
            assert_eq!(scalar_at(&sparse, position).unwrap(), expected);
        }

        match scalar_at(&sparse, 10).unwrap_err() {
            VortexError::OutOfBounds(i, start, stop, _) => {
                assert_eq!(i, 10);
                assert_eq!(start, 0);
                assert_eq!(stop, 10);
            }
            _ => unreachable!(),
        }
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = scalar_at(&arr, 10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        for unpatched in [0, 99] {
            assert!(scalar_at(&arr, unpatched).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        match scalar_at(&sliced, 5).unwrap_err() {
            VortexError::OutOfBounds(i, start, stop, _) => {
                assert_eq!(i, 5);
                assert_eq!(start, 0);
                assert_eq!(stop, 5);
            }
            _ => unreachable!(),
        }
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();
        let expected = Mask::from_iter(vec![true, false, false, true, false]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let unsliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .into_array();

        let sliced = slice(&unsliced, 2, 7).unwrap();
        let expected = Mask::from_iter(vec![false, true, true, false, true]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();

        let once_value = scalar_at(&sliced_once, 1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        match scalar_at(&sliced_once, 7).unwrap_err() {
            VortexError::OutOfBounds(i, start, stop, _) => {
                assert_eq!(i, 7);
                assert_eq!(start, 0);
                assert_eq!(stop, 7);
            }
            _ => unreachable!(),
        }

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();

        let twice_value = scalar_at(&sliced_twice, 3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);

        match scalar_at(&sliced_twice, 5).unwrap_err() {
            VortexError::OutOfBounds(i, start, stop, _) => {
                assert_eq!(i, 5);
                assert_eq!(start, 0);
                assert_eq!(stop, 5);
            }
            _ => unreachable!(),
        }
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let validity_bits = array
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();

        // Only the patched positions 2, 5 and 8 are valid.
        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 is out of range for a length of 100, so unwrapping must panic.
        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        assert_eq!(sparse.validity_mask().unwrap(), Mask::from_iter(validity));
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}