1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::ToPrimitive;
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{
18 SearchResult, SearchSortedSide, filter, scalar_at, search_sorted, search_sorted_usize,
19 search_sorted_usize_many, slice, take, try_cast,
20};
21use crate::variants::PrimitiveArrayTrait;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical};
23
/// Compact, serializable description of a set of patches.
///
/// It records how many patches exist, the offset applied to their indices, and the primitive
/// type used to store those indices (see `Patches::to_metadata`, which fills these fields from
/// a live `Patches` value).
///
/// NOTE(review): the struct is `#[repr(C)]` and derives both `serde` and `rkyv` traits, so the
/// field order and field types are presumably part of a persisted layout — confirm before
/// reordering or retyping anything here.
#[derive(
    Copy,
    Clone,
    Debug,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::bytecheck::CheckBytes,
)]
#[bytecheck(crate = rkyv::bytecheck)]
#[repr(C)]
pub struct PatchesMetadata {
    /// Number of patches (the length of the indices array).
    len: usize,
    /// Offset that the stored patch indices are shifted by.
    offset: usize,
    /// Primitive type of the stored indices; `indices_dtype` asserts it is an unsigned integer.
    indices_ptype: PType,
}
42
43impl PatchesMetadata {
44 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
45 Self {
46 len,
47 offset,
48 indices_ptype,
49 }
50 }
51
52 #[inline]
53 pub fn len(&self) -> usize {
54 self.len
55 }
56
57 #[inline]
58 pub fn is_empty(&self) -> bool {
59 self.len == 0
60 }
61
62 #[inline]
63 pub fn offset(&self) -> usize {
64 self.offset
65 }
66
67 #[inline]
68 pub fn indices_dtype(&self) -> DType {
69 assert!(
70 self.indices_ptype.is_unsigned_int(),
71 "Patch indices must be unsigned integers"
72 );
73 DType::Primitive(self.indices_ptype, NonNullable)
74 }
75}
76
/// A sparse set of replacement values ("patches") for an array of `array_len` elements.
///
/// `indices` and `values` are parallel arrays: `values[i]` is the patch for the position stored
/// in `indices[i]`. Stored indices are shifted by `offset`, i.e. a stored index `s` refers to
/// logical position `s - offset` (see `min_index`, `max_index` and `search_index`).
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array being patched.
    array_len: usize,
    /// Amount by which the stored indices are shifted relative to logical positions.
    offset: usize,
    /// Unsigned-integer positions of the patches (checked by `Patches::new`); the searches in
    /// this file treat them as sorted ascending.
    indices: ArrayRef,
    /// Patch values, one per entry of `indices`.
    values: ArrayRef,
}
85
86impl Patches {
87 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
88 assert_eq!(
89 indices.len(),
90 values.len(),
91 "Patch indices and values must have the same length"
92 );
93 assert!(
94 indices.dtype().is_unsigned_int(),
95 "Patch indices must be unsigned integers"
96 );
97 assert!(
98 indices.len() <= array_len,
99 "Patch indices must be shorter than the array length"
100 );
101 assert!(!indices.is_empty(), "Patch indices must not be empty");
102 let max = usize::try_from(
103 &scalar_at(&indices, indices.len() - 1).vortex_expect("indices are not empty"),
104 )
105 .vortex_expect("indices must be a number");
106 assert!(
107 max - offset < array_len,
108 "Patch indices {:?}, offset {} are longer than the array length {}",
109 max,
110 offset,
111 array_len
112 );
113 Self::new_unchecked(array_len, offset, indices, values)
114 }
115
116 pub fn new_unchecked(
126 array_len: usize,
127 offset: usize,
128 indices: ArrayRef,
129 values: ArrayRef,
130 ) -> Self {
131 Self {
132 array_len,
133 offset,
134 indices,
135 values,
136 }
137 }
138
139 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
141 (self.array_len, self.offset, self.indices, self.values)
142 }
143
144 pub fn array_len(&self) -> usize {
145 self.array_len
146 }
147
148 pub fn num_patches(&self) -> usize {
149 self.indices.len()
150 }
151
152 pub fn dtype(&self) -> &DType {
153 self.values.dtype()
154 }
155
156 pub fn indices(&self) -> &ArrayRef {
157 &self.indices
158 }
159
160 pub fn into_indices(self) -> ArrayRef {
161 self.indices
162 }
163
164 pub fn indices_mut(&mut self) -> &mut ArrayRef {
165 &mut self.indices
166 }
167
168 pub fn values(&self) -> &ArrayRef {
169 &self.values
170 }
171
172 pub fn into_values(self) -> ArrayRef {
173 self.values
174 }
175
176 pub fn values_mut(&mut self) -> &mut ArrayRef {
177 &mut self.values
178 }
179
180 pub fn offset(&self) -> usize {
181 self.offset
182 }
183
184 pub fn indices_ptype(&self) -> PType {
185 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
186 }
187
188 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
189 if self.indices.len() > len {
190 vortex_bail!(
191 "Patch indices {} are longer than the array length {}",
192 self.indices.len(),
193 len
194 );
195 }
196 if self.values.dtype() != dtype {
197 vortex_bail!(
198 "Patch values dtype {} does not match array dtype {}",
199 self.values.dtype(),
200 dtype
201 );
202 }
203 Ok(PatchesMetadata {
204 len: self.indices.len(),
205 offset: self.offset,
206 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices"),
207 })
208 }
209
210 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
211 Ok(Self::new_unchecked(
212 self.array_len,
213 self.offset,
214 self.indices,
215 try_cast(&self.values, values_dtype)?,
216 ))
217 }
218
219 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
221 if let Some(patch_idx) = self.search_index(index)?.to_found() {
222 scalar_at(self.values(), patch_idx).map(Some)
223 } else {
224 Ok(None)
225 }
226 }
227
228 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
230 search_sorted_usize(&self.indices, index + self.offset, SearchSortedSide::Left)
231 }
232
233 pub fn search_sorted<T: Into<Scalar>>(
235 &self,
236 target: T,
237 side: SearchSortedSide,
238 ) -> VortexResult<SearchResult> {
239 search_sorted(self.values(), target.into(), side).and_then(|sr| {
240 let index_idx = sr.to_offsets_index(self.indices().len(), side);
241 let index = usize::try_from(&scalar_at(self.indices(), index_idx)?)? - self.offset;
242 Ok(match sr {
243 SearchResult::Found(i) => SearchResult::Found(
245 if i == self.indices().len() || side == SearchSortedSide::Right {
246 index + 1
247 } else {
248 index
249 },
250 ),
251 SearchResult::NotFound(i) => {
253 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
254 }
255 })
256 })
257 }
258
259 pub fn min_index(&self) -> VortexResult<usize> {
261 Ok(usize::try_from(&scalar_at(self.indices(), 0)?)? - self.offset)
262 }
263
264 pub fn max_index(&self) -> VortexResult<usize> {
266 Ok(usize::try_from(&scalar_at(self.indices(), self.indices().len() - 1)?)? - self.offset)
267 }
268
269 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
271 match mask.indices() {
272 AllOr::All => Ok(Some(self.clone())),
273 AllOr::None => Ok(None),
274 AllOr::Some(mask_indices) => {
275 let flat_indices = self.indices().to_primitive()?;
276 match_each_integer_ptype!(flat_indices.ptype(), |$I| {
277 filter_patches_with_mask(
278 flat_indices.as_slice::<$I>(),
279 self.offset(),
280 self.values(),
281 mask_indices,
282 )
283 })
284 }
285 }
286 }
287
288 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
290 let patch_start = self.search_index(start)?.to_index();
291 let patch_stop = self.search_index(stop)?.to_index();
292
293 if patch_start == patch_stop {
294 return Ok(None);
295 }
296
297 let values = slice(self.values(), patch_start, patch_stop)?;
299 let indices = slice(self.indices(), patch_start, patch_stop)?;
300
301 Ok(Some(Self::new(
302 stop - start,
303 start + self.offset(),
304 indices,
305 values,
306 )))
307 }
308
309 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
311
312 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
313 (self.num_patches() as f64 / take_indices.len() as f64)
314 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
315 }
316
317 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
319 if take_indices.is_empty() {
320 return Ok(None);
321 }
322 let take_indices = take_indices.to_primitive()?;
323 if self.is_map_faster_than_search(&take_indices) {
324 self.take_map(take_indices)
325 } else {
326 self.take_search(take_indices)
327 }
328 }
329
330 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
331 let new_length = take_indices.len();
332
333 let Some((new_indices, values_indices)) = match_each_integer_ptype!(take_indices.ptype(), |$I| {
334 take_search::<$I>(self.indices(), take_indices, self.offset())?
335 }) else {
336 return Ok(None);
337 };
338
339 Ok(Some(Self::new(
340 new_length,
341 0,
342 new_indices,
343 take(self.values(), &values_indices)?,
344 )))
345 }
346
347 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
348 let indices = self.indices.to_primitive()?;
349 let new_length = take_indices.len();
350
351 let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
352 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
353 take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
354 })
355 }) else {
356 return Ok(None);
357 };
358
359 Ok(Some(Patches::new(
360 new_length,
361 0,
362 new_sparse_indices,
363 take(self.values(), &value_indices)?,
364 )))
365 }
366
367 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
368 where
369 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
370 {
371 let values = f(self.values)?;
372 if self.indices.len() != values.len() {
373 vortex_bail!(
374 "map_values must preserve length: expected {} received {}",
375 self.indices.len(),
376 values.len()
377 )
378 }
379 Ok(Self::new(self.array_len, self.offset, self.indices, values))
380 }
381}
382
383fn take_search<T: NativePType + TryFrom<usize>>(
384 indices: &dyn Array,
385 take_indices: PrimitiveArray,
386 indices_offset: usize,
387) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
388where
389 usize: TryFrom<T>,
390 VortexError: From<<usize as TryFrom<T>>::Error>,
391{
392 let take_indices_validity = take_indices.validity();
393 let take_indices = take_indices
394 .as_slice::<T>()
395 .iter()
396 .copied()
397 .map(usize::try_from)
398 .map_ok(|idx| idx + indices_offset)
399 .collect::<Result<Vec<_>, _>>()?;
400
401 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) =
402 search_sorted_usize_many(indices, &take_indices, SearchSortedSide::Left)?
403 .iter()
404 .enumerate()
405 .filter_map(|(idx_in_take, search_result)| {
406 search_result
407 .to_found()
408 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
409 })
410 .unzip();
411
412 if new_indices.is_empty() {
413 return Ok(None);
414 }
415
416 let new_indices = new_indices.into_array();
417 let values_validity = take_indices_validity.take(&new_indices)?;
418 Ok(Some((
419 new_indices,
420 PrimitiveArray::new(values_indices, values_validity).into_array(),
421 )))
422}
423
424fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
425 indices: &[I],
426 take_indices: PrimitiveArray,
427 indices_offset: usize,
428 min_index: usize,
429 max_index: usize,
430) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
431where
432 usize: TryFrom<T>,
433 VortexError: From<<I as TryFrom<usize>>::Error>,
434{
435 let take_indices_validity = take_indices.validity();
436 let take_indices = take_indices.as_slice::<T>();
437 let offset_i = I::try_from(indices_offset)?;
438
439 let sparse_index_to_value_index: HashMap<I, usize> = indices
440 .iter()
441 .copied()
442 .map(|idx| idx - offset_i)
443 .enumerate()
444 .map(|(value_index, sparse_index)| (sparse_index, value_index))
445 .collect();
446 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
447 .iter()
448 .copied()
449 .map(usize::try_from)
450 .process_results(|iter| {
451 iter.enumerate()
452 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
453 .filter_map(|(new_sparse_index, take_sparse_index)| {
454 sparse_index_to_value_index
455 .get(
456 &I::try_from(take_sparse_index)
457 .vortex_expect("take_sparse_index is between min and max index"),
458 )
459 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
460 })
461 .unzip()
462 })
463 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
464
465 if new_sparse_indices.is_empty() {
466 return Ok(None);
467 }
468
469 let new_sparse_indices = new_sparse_indices.into_array();
470 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
471 Ok(Some((
472 new_sparse_indices,
473 PrimitiveArray::new(value_indices, values_validity).into_array(),
474 )))
475}
476
477fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
483 patch_indices: &[T],
484 offset: usize,
485 patch_values: &dyn Array,
486 mask_indices: &[usize],
487) -> VortexResult<Option<Patches>> {
488 let true_count = mask_indices.len();
489 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
490 let mut new_mask_indices = Vec::with_capacity(true_count);
491
492 const STRIDE: usize = 4;
496
497 let mut mask_idx = 0usize;
498 let mut true_idx = 0usize;
499
500 while mask_idx < patch_indices.len() && true_idx < true_count {
501 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
508 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
510 let left_max = patch_indices[mask_idx + STRIDE]
511 .to_usize()
512 .vortex_expect("left_max")
513 - offset;
514 let right_min = mask_indices[true_idx];
515 let right_max = mask_indices[true_idx + STRIDE];
516
517 if left_min > right_max {
518 true_idx += STRIDE;
520 continue;
521 } else if right_min > left_max {
522 mask_idx += STRIDE;
523 continue;
524 } else {
525 }
527 }
528
529 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
532 let right = mask_indices[true_idx];
533
534 match left.cmp(&right) {
535 Ordering::Less => {
536 mask_idx += 1;
537 }
538 Ordering::Greater => {
539 true_idx += 1;
540 }
541 Ordering::Equal => {
542 new_mask_indices.push(mask_idx);
544 new_patch_indices.push(true_idx as u64);
545
546 mask_idx += 1;
547 true_idx += 1;
548 }
549 }
550 }
551
552 if new_mask_indices.is_empty() {
553 return Ok(None);
554 }
555
556 let new_patch_indices = new_patch_indices.into_array();
557 let new_patch_values = filter(
558 patch_values,
559 &Mask::from_indices(patch_values.len(), new_mask_indices),
560 )?;
561
562 Ok(Some(Patches::new(
563 true_count,
564 0,
565 new_patch_indices,
566 new_patch_values,
567 )))
568}
569
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::compute::{SearchResult, SearchSortedSide};
    use crate::patches::Patches;
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps only the patches at selected positions and re-indexes them to their rank
    /// within the mask.
    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Constructing patches whose last index lies below the offset must be rejected.
    #[test]
    #[should_panic]
    fn new_rejects_last_index_below_offset() {
        let _ = Patches::new(
            10,
            5,
            buffer![2u64].into_array(),
            buffer![1i32].into_array(),
        );
    }

    /// Three patches at positions 2, 9 and 15 of a 20-element array.
    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// Searching sliced patches reports positions relative to the slice start.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let patches = Patches::new(
            6,
            0,
            buffer![0u8, 1, 4, 5].into_array(),
            buffer![-128i8, -98, 8, 50].into_array(),
        );

        assert_eq!(
            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(2)
        );
        assert_eq!(
            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(6),
        );
        assert_eq!(
            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(2),
        );
        assert_eq!(
            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(6)
        );
    }

    #[test]
    fn search_left() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        assert_eq!(
            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(19)
        );
    }

    /// A null take index never produces a patch; only the valid hit at position 9 survives.
    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// The slice end is exclusive: the patch at index 100 is dropped by `slice(15, 100)`.
    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
    }

    /// Slicing already-sliced patches composes the offsets correctly.
    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();

        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
    }
}