1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::{NumCast, ToPrimitive};
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::{PValue, Scalar};
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{cast, filter, take};
18use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
19use crate::vtable::ValidityHelper;
20use crate::{Array, ArrayRef, IntoArray, ToCanonical};
21
/// Serializable description of a [`Patches`] instance: how many patches there are, the offset
/// applied to their indices, and the primitive type used to store those indices.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches (the length of the patch indices / values arrays).
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset of the patch indices (the value of `Patches::offset`).
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// `PType` of the patch indices, stored as its prost enumeration value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
31
32impl PatchesMetadata {
33 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
34 Self {
35 len: len as u64,
36 offset: offset as u64,
37 indices_ptype: indices_ptype as i32,
38 }
39 }
40
41 #[inline]
42 pub fn len(&self) -> usize {
43 usize::try_from(self.len).vortex_expect("len is a valid usize")
44 }
45
46 #[inline]
47 pub fn is_empty(&self) -> bool {
48 self.len == 0
49 }
50
51 #[inline]
52 pub fn offset(&self) -> usize {
53 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
54 }
55
56 #[inline]
57 pub fn indices_dtype(&self) -> DType {
58 assert!(
59 self.indices_ptype().is_unsigned_int(),
60 "Patch indices must be unsigned integers"
61 );
62 DType::Primitive(self.indices_ptype(), NonNullable)
63 }
64}
65
/// A sparse set of replacement values applied on top of an array of length `array_len`.
///
/// The stored index `indices[i]` minus `offset` is the array position replaced by `values[i]`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount subtracted from each stored index to obtain its array position; becomes non-zero
    /// after slicing.
    offset: usize,
    /// Unsigned integer positions of the patched elements, shifted by `offset`.
    indices: ArrayRef,
    /// Replacement values, one per entry of `indices`.
    values: ArrayRef,
}
74
75impl Patches {
76 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
77 assert_eq!(
78 indices.len(),
79 values.len(),
80 "Patch indices and values must have the same length"
81 );
82 assert!(
83 indices.dtype().is_unsigned_int(),
84 "Patch indices must be unsigned integers"
85 );
86 assert!(
87 indices.len() <= array_len,
88 "Patch indices must be shorter than the array length"
89 );
90 assert!(!indices.is_empty(), "Patch indices must not be empty");
91 let max = usize::try_from(
92 &indices
93 .scalar_at(indices.len() - 1)
94 .vortex_expect("indices are not empty"),
95 )
96 .vortex_expect("indices must be a number");
97 assert!(
98 max - offset < array_len,
99 "Patch indices {max:?}, offset {offset} are longer than the array length {array_len}"
100 );
101 Self::new_unchecked(array_len, offset, indices, values)
102 }
103
104 pub fn new_unchecked(
114 array_len: usize,
115 offset: usize,
116 indices: ArrayRef,
117 values: ArrayRef,
118 ) -> Self {
119 Self {
120 array_len,
121 offset,
122 indices,
123 values,
124 }
125 }
126
127 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
129 (self.array_len, self.offset, self.indices, self.values)
130 }
131
132 pub fn array_len(&self) -> usize {
133 self.array_len
134 }
135
136 pub fn num_patches(&self) -> usize {
137 self.indices.len()
138 }
139
140 pub fn dtype(&self) -> &DType {
141 self.values.dtype()
142 }
143
144 pub fn indices(&self) -> &ArrayRef {
145 &self.indices
146 }
147
148 pub fn into_indices(self) -> ArrayRef {
149 self.indices
150 }
151
152 pub fn indices_mut(&mut self) -> &mut ArrayRef {
153 &mut self.indices
154 }
155
156 pub fn values(&self) -> &ArrayRef {
157 &self.values
158 }
159
160 pub fn into_values(self) -> ArrayRef {
161 self.values
162 }
163
164 pub fn values_mut(&mut self) -> &mut ArrayRef {
165 &mut self.values
166 }
167
168 pub fn offset(&self) -> usize {
169 self.offset
170 }
171
172 pub fn indices_ptype(&self) -> PType {
173 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
174 }
175
176 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
177 if self.indices.len() > len {
178 vortex_bail!(
179 "Patch indices {} are longer than the array length {}",
180 self.indices.len(),
181 len
182 );
183 }
184 if self.values.dtype() != dtype {
185 vortex_bail!(
186 "Patch values dtype {} does not match array dtype {}",
187 self.values.dtype(),
188 dtype
189 );
190 }
191 Ok(PatchesMetadata {
192 len: self.indices.len() as u64,
193 offset: self.offset as u64,
194 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
195 as i32,
196 })
197 }
198
199 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
200 Ok(Self::new_unchecked(
201 self.array_len,
202 self.offset,
203 self.indices,
204 cast(&self.values, values_dtype)?,
205 ))
206 }
207
208 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
210 if let Some(patch_idx) = self.search_index(index)?.to_found() {
211 self.values().scalar_at(patch_idx).map(Some)
212 } else {
213 Ok(None)
214 }
215 }
216
217 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
219 Ok(self.indices.as_primitive_typed().search_sorted(
220 &PValue::U64((index + self.offset) as u64),
221 SearchSortedSide::Left,
222 ))
223 }
224
225 pub fn search_sorted<T: Into<Scalar>>(
227 &self,
228 target: T,
229 side: SearchSortedSide,
230 ) -> VortexResult<SearchResult> {
231 let target = target.into();
232
233 let sr = if self.values().dtype().is_primitive() {
234 self.values()
235 .as_primitive_typed()
236 .search_sorted(&target.as_primitive().pvalue(), side)
237 } else {
238 self.values().search_sorted(&target, side)
239 };
240
241 let index_idx = sr.to_offsets_index(self.indices().len(), side);
242 let index = usize::try_from(&self.indices().scalar_at(index_idx)?)? - self.offset;
243 Ok(match sr {
244 SearchResult::Found(i) => SearchResult::Found(
246 if i == self.indices().len() || side == SearchSortedSide::Right {
247 index + 1
248 } else {
249 index
250 },
251 ),
252 SearchResult::NotFound(i) => {
254 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
255 }
256 })
257 }
258
259 pub fn min_index(&self) -> VortexResult<usize> {
261 Ok(usize::try_from(&self.indices().scalar_at(0)?)? - self.offset)
262 }
263
264 pub fn max_index(&self) -> VortexResult<usize> {
266 Ok(usize::try_from(&self.indices().scalar_at(self.indices().len() - 1)?)? - self.offset)
267 }
268
269 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
271 match mask.indices() {
272 AllOr::All => Ok(Some(self.clone())),
273 AllOr::None => Ok(None),
274 AllOr::Some(mask_indices) => {
275 let flat_indices = self.indices().to_primitive()?;
276 match_each_integer_ptype!(flat_indices.ptype(), |$I| {
277 filter_patches_with_mask(
278 flat_indices.as_slice::<$I>(),
279 self.offset(),
280 self.values(),
281 mask_indices,
282 )
283 })
284 }
285 }
286 }
287
288 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
290 let patch_start = self.search_index(start)?.to_index();
291 let patch_stop = self.search_index(stop)?.to_index();
292
293 if patch_start == patch_stop {
294 return Ok(None);
295 }
296
297 let values = self.values().slice(patch_start, patch_stop)?;
299 let indices = self.indices().slice(patch_start, patch_stop)?;
300
301 Ok(Some(Self::new(
302 stop - start,
303 start + self.offset(),
304 indices,
305 values,
306 )))
307 }
308
309 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
311
312 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
313 (self.num_patches() as f64 / take_indices.len() as f64)
314 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
315 }
316
317 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
319 if take_indices.is_empty() {
320 return Ok(None);
321 }
322 let take_indices = take_indices.to_primitive()?;
323 if self.is_map_faster_than_search(&take_indices) {
324 self.take_map(take_indices)
325 } else {
326 self.take_search(take_indices)
327 }
328 }
329
330 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
331 let indices = self.indices.to_primitive()?;
332 let new_length = take_indices.len();
333
334 let Some((new_indices, values_indices)) = match_each_integer_ptype!(indices.ptype(), |$INDICES| {
335 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
336 take_search::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset())?
337 })
338 }) else {
339 return Ok(None);
340 };
341
342 Ok(Some(Self::new(
343 new_length,
344 0,
345 new_indices,
346 take(self.values(), &values_indices)?,
347 )))
348 }
349
350 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
351 let indices = self.indices.to_primitive()?;
352 let new_length = take_indices.len();
353
354 let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
355 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
356 take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
357 })
358 }) else {
359 return Ok(None);
360 };
361
362 Ok(Some(Patches::new(
363 new_length,
364 0,
365 new_sparse_indices,
366 take(self.values(), &value_indices)?,
367 )))
368 }
369
370 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
371 where
372 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
373 {
374 let values = f(self.values)?;
375 if self.indices.len() != values.len() {
376 vortex_bail!(
377 "map_values must preserve length: expected {} received {}",
378 self.indices.len(),
379 values.len()
380 )
381 }
382 Ok(Self::new(self.array_len, self.offset, self.indices, values))
383 }
384}
385
386fn take_search<I: NativePType + NumCast + PartialOrd, T: NativePType + NumCast>(
387 indices: &[I],
388 take_indices: PrimitiveArray,
389 indices_offset: usize,
390) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
391where
392 usize: TryFrom<T>,
393 VortexError: From<<usize as TryFrom<T>>::Error>,
394{
395 let take_indices_validity = take_indices.validity();
396 let indices_offset = I::from(indices_offset).vortex_expect("indices_offset out of range");
397
398 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
399 .as_slice::<T>()
400 .iter()
401 .map(|v| {
402 match I::from(*v) {
403 None => {
404 SearchResult::NotFound(indices.len())
406 }
407 Some(v) => indices.search_sorted(&(v + indices_offset), SearchSortedSide::Left),
408 }
409 })
410 .enumerate()
411 .filter_map(|(idx_in_take, search_result)| {
412 search_result
413 .to_found()
414 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
415 })
416 .unzip();
417
418 if new_indices.is_empty() {
419 return Ok(None);
420 }
421
422 let new_indices = new_indices.into_array();
423 let values_validity = take_indices_validity.take(&new_indices)?;
424 Ok(Some((
425 new_indices,
426 PrimitiveArray::new(values_indices, values_validity).into_array(),
427 )))
428}
429
430fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
431 indices: &[I],
432 take_indices: PrimitiveArray,
433 indices_offset: usize,
434 min_index: usize,
435 max_index: usize,
436) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
437where
438 usize: TryFrom<T>,
439 VortexError: From<<I as TryFrom<usize>>::Error>,
440{
441 let take_indices_validity = take_indices.validity();
442 let take_indices = take_indices.as_slice::<T>();
443 let offset_i = I::try_from(indices_offset)?;
444
445 let sparse_index_to_value_index: HashMap<I, usize> = indices
446 .iter()
447 .copied()
448 .map(|idx| idx - offset_i)
449 .enumerate()
450 .map(|(value_index, sparse_index)| (sparse_index, value_index))
451 .collect();
452 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
453 .iter()
454 .copied()
455 .map(usize::try_from)
456 .process_results(|iter| {
457 iter.enumerate()
458 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
459 .filter_map(|(new_sparse_index, take_sparse_index)| {
460 sparse_index_to_value_index
461 .get(
462 &I::try_from(take_sparse_index)
463 .vortex_expect("take_sparse_index is between min and max index"),
464 )
465 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
466 })
467 .unzip()
468 })
469 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
470
471 if new_sparse_indices.is_empty() {
472 return Ok(None);
473 }
474
475 let new_sparse_indices = new_sparse_indices.into_array();
476 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
477 Ok(Some((
478 new_sparse_indices,
479 PrimitiveArray::new(value_indices, values_validity).into_array(),
480 )))
481}
482
483fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
489 patch_indices: &[T],
490 offset: usize,
491 patch_values: &dyn Array,
492 mask_indices: &[usize],
493) -> VortexResult<Option<Patches>> {
494 let true_count = mask_indices.len();
495 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
496 let mut new_mask_indices = Vec::with_capacity(true_count);
497
498 const STRIDE: usize = 4;
502
503 let mut mask_idx = 0usize;
504 let mut true_idx = 0usize;
505
506 while mask_idx < patch_indices.len() && true_idx < true_count {
507 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
514 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
516 let left_max = patch_indices[mask_idx + STRIDE]
517 .to_usize()
518 .vortex_expect("left_max")
519 - offset;
520 let right_min = mask_indices[true_idx];
521 let right_max = mask_indices[true_idx + STRIDE];
522
523 if left_min > right_max {
524 true_idx += STRIDE;
526 continue;
527 } else if right_min > left_max {
528 mask_idx += STRIDE;
529 continue;
530 } else {
531 }
533 }
534
535 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
538 let right = mask_indices[true_idx];
539
540 match left.cmp(&right) {
541 Ordering::Less => {
542 mask_idx += 1;
543 }
544 Ordering::Greater => {
545 true_idx += 1;
546 }
547 Ordering::Equal => {
548 new_mask_indices.push(mask_idx);
550 new_patch_indices.push(true_idx as u64);
551
552 mask_idx += 1;
553 true_idx += 1;
554 }
555 }
556 }
557
558 if new_mask_indices.is_empty() {
559 return Ok(None);
560 }
561
562 let new_patch_indices = new_patch_indices.into_array();
563 let new_patch_values = filter(
564 patch_values,
565 &Mask::from_indices(patch_values.len(), new_mask_indices),
566 )?;
567
568 Ok(Some(Patches::new(
569 true_count,
570 0,
571 new_patch_indices,
572 new_patch_values,
573 )))
574}
575
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::arrays::PrimitiveArray;
    use crate::patches::Patches;
    use crate::search_sorted::{SearchResult, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Enough interleaved patches and mask positions to exercise the strided skip-ahead path.
    #[test]
    fn filter_strided() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u32, 2, 4, 6, 8, 10, 12, 14, 16, 18].into_array(),
            buffer![0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(20, (4..14).collect()))
            .unwrap()
            .unwrap();

        assert_eq!(filtered.array_len(), 10);
        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 2, 4, 6, 8]);
        assert_eq!(values.as_slice::<i32>(), &[2, 3, 4, 5, 6]);
    }

    /// A mask that selects no patched position yields no patches.
    #[test]
    fn filter_disjoint() {
        let patches = Patches::new(
            100,
            0,
            buffer![0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9].into_array(),
            buffer![0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, (50..60).collect()))
            .unwrap();

        assert!(filtered.is_none());
    }

    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let patches = Patches::new(
            6,
            0,
            buffer![0u8, 1, 4, 5].into_array(),
            buffer![-128i8, -98, 8, 50].into_array(),
        );

        assert_eq!(
            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(2)
        );
        assert_eq!(
            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(6),
        );
        assert_eq!(
            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(2),
        );
        assert_eq!(
            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(6)
        );
    }

    #[test]
    fn search_left() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        assert_eq!(
            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(19)
        );
    }

    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    /// Both take strategies must produce identical patches for the same input.
    #[rstest]
    fn take_search_and_map_agree(patches: Patches) {
        let take_indices = || PrimitiveArray::new(buffer![9u64, 2, 3], Validity::NonNullable);

        for taken in [
            patches.take_search(take_indices()).unwrap().unwrap(),
            patches.take_map(take_indices()).unwrap().unwrap(),
        ] {
            assert_eq!(taken.array_len(), 3);
            let indices = taken.indices().to_primitive().unwrap();
            let values = taken.values().to_primitive().unwrap();
            assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
            assert_eq!(values.as_slice::<i32>(), &[44, 33]);
        }
    }

    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
    }

    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();

        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
    }
}