1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use itertools::Itertools as _;
6use num_traits::ToPrimitive;
7use serde::{Deserialize, Serialize};
8use vortex_buffer::BufferMut;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype};
11use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15use crate::aliases::hash_map::HashMap;
16use crate::arrays::PrimitiveArray;
17use crate::compute::{
18 SearchResult, SearchSortedSide, cast, filter, scalar_at, search_sorted, search_sorted_usize,
19 search_sorted_usize_many, slice, take,
20};
21use crate::variants::PrimitiveArrayTrait;
22use crate::{Array, ArrayRef, IntoArray, ToCanonical};
23
/// Serializable description of a [`Patches`] instance.
///
/// Only the scalar properties are recorded here: the number of patches, the index offset and the
/// primitive type of the indices. The indices and values arrays are not part of this struct.
#[derive(Copy, Clone, Serialize, Deserialize, prost::Message)]
pub struct PatchesMetadata {
    /// Number of patches (the length of the indices array).
    #[prost(uint64, tag = "1")]
    len: u64,
    /// Offset carried by the patch indices (see [`Patches::offset`]).
    #[prost(uint64, tag = "2")]
    offset: u64,
    /// Primitive type of the patch indices, stored as the `PType` enumeration value.
    #[prost(enumeration = "PType", tag = "3")]
    indices_ptype: i32,
}
33
34impl PatchesMetadata {
35 pub fn new(len: usize, offset: usize, indices_ptype: PType) -> Self {
36 Self {
37 len: len as u64,
38 offset: offset as u64,
39 indices_ptype: indices_ptype as i32,
40 }
41 }
42
43 #[inline]
44 pub fn len(&self) -> usize {
45 usize::try_from(self.len).vortex_expect("len is a valid usize")
46 }
47
48 #[inline]
49 pub fn is_empty(&self) -> bool {
50 self.len == 0
51 }
52
53 #[inline]
54 pub fn offset(&self) -> usize {
55 usize::try_from(self.offset).vortex_expect("offset is a valid usize")
56 }
57
58 #[inline]
59 pub fn indices_dtype(&self) -> DType {
60 assert!(
61 self.indices_ptype().is_unsigned_int(),
62 "Patch indices must be unsigned integers"
63 );
64 DType::Primitive(self.indices_ptype(), NonNullable)
65 }
66}
67
/// A sparse set of `(index, value)` overrides for an array of length `array_len`.
///
/// `indices` and `values` always have the same length. The stored indices carry an `offset`:
/// position `i` of the patched array is looked up as `i + offset` in `indices`.
#[derive(Debug, Clone)]
pub struct Patches {
    /// Length of the array the patches apply to.
    array_len: usize,
    /// Amount by which the stored indices are shifted relative to array positions.
    offset: usize,
    /// Unsigned-integer array of patched positions (shifted by `offset`).
    /// Assumed to be sorted ascending, since lookups use sorted search — TODO confirm.
    indices: ArrayRef,
    /// Patch values; `values[k]` belongs to `indices[k]`.
    values: ArrayRef,
}
76
77impl Patches {
78 pub fn new(array_len: usize, offset: usize, indices: ArrayRef, values: ArrayRef) -> Self {
79 assert_eq!(
80 indices.len(),
81 values.len(),
82 "Patch indices and values must have the same length"
83 );
84 assert!(
85 indices.dtype().is_unsigned_int(),
86 "Patch indices must be unsigned integers"
87 );
88 assert!(
89 indices.len() <= array_len,
90 "Patch indices must be shorter than the array length"
91 );
92 assert!(!indices.is_empty(), "Patch indices must not be empty");
93 let max = usize::try_from(
94 &scalar_at(&indices, indices.len() - 1).vortex_expect("indices are not empty"),
95 )
96 .vortex_expect("indices must be a number");
97 assert!(
98 max - offset < array_len,
99 "Patch indices {:?}, offset {} are longer than the array length {}",
100 max,
101 offset,
102 array_len
103 );
104 Self::new_unchecked(array_len, offset, indices, values)
105 }
106
107 pub fn new_unchecked(
117 array_len: usize,
118 offset: usize,
119 indices: ArrayRef,
120 values: ArrayRef,
121 ) -> Self {
122 Self {
123 array_len,
124 offset,
125 indices,
126 values,
127 }
128 }
129
130 pub fn into_parts(self) -> (usize, usize, ArrayRef, ArrayRef) {
132 (self.array_len, self.offset, self.indices, self.values)
133 }
134
135 pub fn array_len(&self) -> usize {
136 self.array_len
137 }
138
139 pub fn num_patches(&self) -> usize {
140 self.indices.len()
141 }
142
143 pub fn dtype(&self) -> &DType {
144 self.values.dtype()
145 }
146
147 pub fn indices(&self) -> &ArrayRef {
148 &self.indices
149 }
150
151 pub fn into_indices(self) -> ArrayRef {
152 self.indices
153 }
154
155 pub fn indices_mut(&mut self) -> &mut ArrayRef {
156 &mut self.indices
157 }
158
159 pub fn values(&self) -> &ArrayRef {
160 &self.values
161 }
162
163 pub fn into_values(self) -> ArrayRef {
164 self.values
165 }
166
167 pub fn values_mut(&mut self) -> &mut ArrayRef {
168 &mut self.values
169 }
170
171 pub fn offset(&self) -> usize {
172 self.offset
173 }
174
175 pub fn indices_ptype(&self) -> PType {
176 PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
177 }
178
179 pub fn to_metadata(&self, len: usize, dtype: &DType) -> VortexResult<PatchesMetadata> {
180 if self.indices.len() > len {
181 vortex_bail!(
182 "Patch indices {} are longer than the array length {}",
183 self.indices.len(),
184 len
185 );
186 }
187 if self.values.dtype() != dtype {
188 vortex_bail!(
189 "Patch values dtype {} does not match array dtype {}",
190 self.values.dtype(),
191 dtype
192 );
193 }
194 Ok(PatchesMetadata {
195 len: self.indices.len() as u64,
196 offset: self.offset as u64,
197 indices_ptype: PType::try_from(self.indices.dtype()).vortex_expect("primitive indices")
198 as i32,
199 })
200 }
201
202 pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
203 Ok(Self::new_unchecked(
204 self.array_len,
205 self.offset,
206 self.indices,
207 cast(&self.values, values_dtype)?,
208 ))
209 }
210
211 pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
213 if let Some(patch_idx) = self.search_index(index)?.to_found() {
214 scalar_at(self.values(), patch_idx).map(Some)
215 } else {
216 Ok(None)
217 }
218 }
219
220 pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
222 search_sorted_usize(&self.indices, index + self.offset, SearchSortedSide::Left)
223 }
224
225 pub fn search_sorted<T: Into<Scalar>>(
227 &self,
228 target: T,
229 side: SearchSortedSide,
230 ) -> VortexResult<SearchResult> {
231 search_sorted(self.values(), target.into(), side).and_then(|sr| {
232 let index_idx = sr.to_offsets_index(self.indices().len(), side);
233 let index = usize::try_from(&scalar_at(self.indices(), index_idx)?)? - self.offset;
234 Ok(match sr {
235 SearchResult::Found(i) => SearchResult::Found(
237 if i == self.indices().len() || side == SearchSortedSide::Right {
238 index + 1
239 } else {
240 index
241 },
242 ),
243 SearchResult::NotFound(i) => {
245 SearchResult::NotFound(if i == 0 { index } else { index + 1 })
246 }
247 })
248 })
249 }
250
251 pub fn min_index(&self) -> VortexResult<usize> {
253 Ok(usize::try_from(&scalar_at(self.indices(), 0)?)? - self.offset)
254 }
255
256 pub fn max_index(&self) -> VortexResult<usize> {
258 Ok(usize::try_from(&scalar_at(self.indices(), self.indices().len() - 1)?)? - self.offset)
259 }
260
261 pub fn filter(&self, mask: &Mask) -> VortexResult<Option<Self>> {
263 match mask.indices() {
264 AllOr::All => Ok(Some(self.clone())),
265 AllOr::None => Ok(None),
266 AllOr::Some(mask_indices) => {
267 let flat_indices = self.indices().to_primitive()?;
268 match_each_integer_ptype!(flat_indices.ptype(), |$I| {
269 filter_patches_with_mask(
270 flat_indices.as_slice::<$I>(),
271 self.offset(),
272 self.values(),
273 mask_indices,
274 )
275 })
276 }
277 }
278 }
279
280 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Option<Self>> {
282 let patch_start = self.search_index(start)?.to_index();
283 let patch_stop = self.search_index(stop)?.to_index();
284
285 if patch_start == patch_stop {
286 return Ok(None);
287 }
288
289 let values = slice(self.values(), patch_start, patch_stop)?;
291 let indices = slice(self.indices(), patch_start, patch_stop)?;
292
293 Ok(Some(Self::new(
294 stop - start,
295 start + self.offset(),
296 indices,
297 values,
298 )))
299 }
300
301 const PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN: f64 = 5.0;
303
304 fn is_map_faster_than_search(&self, take_indices: &PrimitiveArray) -> bool {
305 (self.num_patches() as f64 / take_indices.len() as f64)
306 < Self::PREFER_MAP_WHEN_PATCHES_OVER_INDICES_LESS_THAN
307 }
308
309 pub fn take(&self, take_indices: &dyn Array) -> VortexResult<Option<Self>> {
311 if take_indices.is_empty() {
312 return Ok(None);
313 }
314 let take_indices = take_indices.to_primitive()?;
315 if self.is_map_faster_than_search(&take_indices) {
316 self.take_map(take_indices)
317 } else {
318 self.take_search(take_indices)
319 }
320 }
321
322 pub fn take_search(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
323 let new_length = take_indices.len();
324
325 let Some((new_indices, values_indices)) = match_each_integer_ptype!(take_indices.ptype(), |$I| {
326 take_search::<$I>(self.indices(), take_indices, self.offset())?
327 }) else {
328 return Ok(None);
329 };
330
331 Ok(Some(Self::new(
332 new_length,
333 0,
334 new_indices,
335 take(self.values(), &values_indices)?,
336 )))
337 }
338
339 pub fn take_map(&self, take_indices: PrimitiveArray) -> VortexResult<Option<Self>> {
340 let indices = self.indices.to_primitive()?;
341 let new_length = take_indices.len();
342
343 let Some((new_sparse_indices, value_indices)) = match_each_integer_ptype!(self.indices_ptype(), |$INDICES| {
344 match_each_integer_ptype!(take_indices.ptype(), |$TAKE_INDICES| {
345 take_map::<_, $TAKE_INDICES>(indices.as_slice::<$INDICES>(), take_indices, self.offset(), self.min_index()?, self.max_index()?)?
346 })
347 }) else {
348 return Ok(None);
349 };
350
351 Ok(Some(Patches::new(
352 new_length,
353 0,
354 new_sparse_indices,
355 take(self.values(), &value_indices)?,
356 )))
357 }
358
359 pub fn map_values<F>(self, f: F) -> VortexResult<Self>
360 where
361 F: FnOnce(ArrayRef) -> VortexResult<ArrayRef>,
362 {
363 let values = f(self.values)?;
364 if self.indices.len() != values.len() {
365 vortex_bail!(
366 "map_values must preserve length: expected {} received {}",
367 self.indices.len(),
368 values.len()
369 )
370 }
371 Ok(Self::new(self.array_len, self.offset, self.indices, values))
372 }
373}
374
375fn take_search<T: NativePType + TryFrom<usize>>(
376 indices: &dyn Array,
377 take_indices: PrimitiveArray,
378 indices_offset: usize,
379) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
380where
381 usize: TryFrom<T>,
382 VortexError: From<<usize as TryFrom<T>>::Error>,
383{
384 let take_indices_validity = take_indices.validity();
385 let take_indices = take_indices
386 .as_slice::<T>()
387 .iter()
388 .copied()
389 .map(usize::try_from)
390 .map_ok(|idx| idx + indices_offset)
391 .collect::<Result<Vec<_>, _>>()?;
392
393 let (values_indices, new_indices): (BufferMut<u64>, BufferMut<u64>) =
394 search_sorted_usize_many(indices, &take_indices, SearchSortedSide::Left)?
395 .iter()
396 .enumerate()
397 .filter_map(|(idx_in_take, search_result)| {
398 search_result
399 .to_found()
400 .map(|patch_idx| (patch_idx as u64, idx_in_take as u64))
401 })
402 .unzip();
403
404 if new_indices.is_empty() {
405 return Ok(None);
406 }
407
408 let new_indices = new_indices.into_array();
409 let values_validity = take_indices_validity.take(&new_indices)?;
410 Ok(Some((
411 new_indices,
412 PrimitiveArray::new(values_indices, values_validity).into_array(),
413 )))
414}
415
416fn take_map<I: NativePType + Hash + Eq + TryFrom<usize>, T: NativePType>(
417 indices: &[I],
418 take_indices: PrimitiveArray,
419 indices_offset: usize,
420 min_index: usize,
421 max_index: usize,
422) -> VortexResult<Option<(ArrayRef, ArrayRef)>>
423where
424 usize: TryFrom<T>,
425 VortexError: From<<I as TryFrom<usize>>::Error>,
426{
427 let take_indices_validity = take_indices.validity();
428 let take_indices = take_indices.as_slice::<T>();
429 let offset_i = I::try_from(indices_offset)?;
430
431 let sparse_index_to_value_index: HashMap<I, usize> = indices
432 .iter()
433 .copied()
434 .map(|idx| idx - offset_i)
435 .enumerate()
436 .map(|(value_index, sparse_index)| (sparse_index, value_index))
437 .collect();
438 let (new_sparse_indices, value_indices): (BufferMut<u64>, BufferMut<u64>) = take_indices
439 .iter()
440 .copied()
441 .map(usize::try_from)
442 .process_results(|iter| {
443 iter.enumerate()
444 .filter(|(_, ti)| *ti >= min_index && *ti <= max_index)
445 .filter_map(|(new_sparse_index, take_sparse_index)| {
446 sparse_index_to_value_index
447 .get(
448 &I::try_from(take_sparse_index)
449 .vortex_expect("take_sparse_index is between min and max index"),
450 )
451 .map(|value_index| (new_sparse_index as u64, *value_index as u64))
452 })
453 .unzip()
454 })
455 .map_err(|_| vortex_err!("Failed to convert index to usize"))?;
456
457 if new_sparse_indices.is_empty() {
458 return Ok(None);
459 }
460
461 let new_sparse_indices = new_sparse_indices.into_array();
462 let values_validity = take_indices_validity.take(&new_sparse_indices)?;
463 Ok(Some((
464 new_sparse_indices,
465 PrimitiveArray::new(value_indices, values_validity).into_array(),
466 )))
467}
468
469fn filter_patches_with_mask<T: ToPrimitive + Copy + Ord>(
475 patch_indices: &[T],
476 offset: usize,
477 patch_values: &dyn Array,
478 mask_indices: &[usize],
479) -> VortexResult<Option<Patches>> {
480 let true_count = mask_indices.len();
481 let mut new_patch_indices = BufferMut::<u64>::with_capacity(true_count);
482 let mut new_mask_indices = Vec::with_capacity(true_count);
483
484 const STRIDE: usize = 4;
488
489 let mut mask_idx = 0usize;
490 let mut true_idx = 0usize;
491
492 while mask_idx < patch_indices.len() && true_idx < true_count {
493 if (mask_idx + STRIDE) < patch_indices.len() && (true_idx + STRIDE) < mask_indices.len() {
500 let left_min = patch_indices[mask_idx].to_usize().vortex_expect("left_min") - offset;
502 let left_max = patch_indices[mask_idx + STRIDE]
503 .to_usize()
504 .vortex_expect("left_max")
505 - offset;
506 let right_min = mask_indices[true_idx];
507 let right_max = mask_indices[true_idx + STRIDE];
508
509 if left_min > right_max {
510 true_idx += STRIDE;
512 continue;
513 } else if right_min > left_max {
514 mask_idx += STRIDE;
515 continue;
516 } else {
517 }
519 }
520
521 let left = patch_indices[mask_idx].to_usize().vortex_expect("left") - offset;
524 let right = mask_indices[true_idx];
525
526 match left.cmp(&right) {
527 Ordering::Less => {
528 mask_idx += 1;
529 }
530 Ordering::Greater => {
531 true_idx += 1;
532 }
533 Ordering::Equal => {
534 new_mask_indices.push(mask_idx);
536 new_patch_indices.push(true_idx as u64);
537
538 mask_idx += 1;
539 true_idx += 1;
540 }
541 }
542 }
543
544 if new_mask_indices.is_empty() {
545 return Ok(None);
546 }
547
548 let new_patch_indices = new_patch_indices.into_array();
549 let new_patch_values = filter(
550 patch_values,
551 &Mask::from_indices(patch_values.len(), new_mask_indices),
552 )?;
553
554 Ok(Some(Patches::new(
555 true_count,
556 0,
557 new_patch_indices,
558 new_patch_values,
559 )))
560}
561
#[cfg(test)]
mod test {
    use rstest::{fixture, rstest};
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::compute::{SearchResult, SearchSortedSide};
    use crate::patches::Patches;
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Filtering keeps only the selected patches and re-indexes them by their rank in the mask.
    #[test]
    fn test_filter() {
        let patches = Patches::new(
            100,
            0,
            buffer![10u32, 11, 20].into_array(),
            buffer![100, 110, 200].into_array(),
        );

        let filtered = patches
            .filter(&Mask::from_indices(100, vec![10, 20, 30]))
            .unwrap()
            .unwrap();

        // Three positions are selected, so the filtered array has length three and no offset.
        assert_eq!(filtered.array_len(), 3);
        assert_eq!(filtered.offset(), 0);

        let indices = filtered.indices().to_primitive().unwrap();
        let values = filtered.values().to_primitive().unwrap();
        assert_eq!(indices.as_slice::<u64>(), &[0, 1]);
        assert_eq!(values.as_slice::<i32>(), &[100, 200]);
    }

    /// Three patches at positions 2, 9 and 15 of a 20 element array.
    #[fixture]
    fn patches() -> Patches {
        Patches::new(
            20,
            0,
            buffer![2u64, 9, 15].into_array(),
            PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
        )
    }

    #[rstest]
    fn search_larger_than(patches: Patches) {
        let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    #[rstest]
    fn search_less_than(patches: Patches) {
        let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::NotFound(2));
    }

    #[rstest]
    fn search_found(patches: Patches) {
        let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
        assert_eq!(res, SearchResult::Found(9));
    }

    #[rstest]
    fn search_not_found_right(patches: Patches) {
        let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
        assert_eq!(res, SearchResult::NotFound(16));
    }

    /// Search results on sliced patches are relative to the slice start.
    #[rstest]
    fn search_sliced(patches: Patches) {
        let sliced = patches.slice(7, 20).unwrap().unwrap();
        assert_eq!(
            sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
    }

    #[test]
    fn search_right() {
        let patches = Patches::new(
            6,
            0,
            buffer![0u8, 1, 4, 5].into_array(),
            buffer![-128i8, -98, 8, 50].into_array(),
        );

        assert_eq!(
            patches.search_sorted(-98, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(2)
        );
        assert_eq!(
            patches.search_sorted(50, SearchSortedSide::Right).unwrap(),
            SearchResult::Found(6),
        );
        assert_eq!(
            patches.search_sorted(7, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(2),
        );
        assert_eq!(
            patches.search_sorted(51, SearchSortedSide::Right).unwrap(),
            SearchResult::NotFound(6)
        );
    }

    #[test]
    fn search_left() {
        let patches = Patches::new(
            20,
            0,
            buffer![0u64, 1, 17, 18, 19].into_array(),
            buffer![11i32, 22, 33, 44, 55].into_array(),
        );
        assert_eq!(
            patches.search_sorted(30, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(2)
        );
        assert_eq!(
            patches.search_sorted(54, SearchSortedSide::Left).unwrap(),
            SearchResult::NotFound(19)
        );
    }

    /// Taking with a nullable index array: position 9 is patched, position 0 is not.
    #[rstest]
    fn take_with_nulls(patches: Patches) {
        let taken = patches
            .take(
                &PrimitiveArray::new(buffer![9, 0], Validity::from_iter(vec![true, false]))
                    .into_array(),
            )
            .unwrap()
            .unwrap();
        let primitive_values = taken.values().to_primitive().unwrap();
        assert_eq!(taken.array_len(), 2);
        assert_eq!(primitive_values.as_slice::<i32>(), [44]);
        assert_eq!(
            primitive_values.validity_mask().unwrap(),
            Mask::from_iter(vec![true])
        );
    }

    #[test]
    fn test_slice() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        // Slicing keeps the stored indices and moves the start into the offset instead.
        assert_eq!(sliced.offset(), 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);
    }

    #[test]
    fn doubly_sliced() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        let patches = Patches::new(101, 0, indices, values);

        let sliced = patches.slice(15, 100).unwrap().unwrap();
        assert_eq!(sliced.array_len(), 100 - 15);
        let primitive = sliced.values().to_primitive().unwrap();

        assert_eq!(primitive.as_slice::<u32>(), &[13531]);

        let doubly_sliced = sliced.slice(35, 36).unwrap().unwrap();
        // Offsets accumulate across slices: 15 from the first slice plus 35 from the second.
        assert_eq!(doubly_sliced.array_len(), 1);
        assert_eq!(doubly_sliced.offset(), 15 + 35);
        let primitive_doubly_sliced = doubly_sliced.values().to_primitive().unwrap();

        assert_eq!(primitive_doubly_sliced.as_slice::<u32>(), &[13531]);
    }
}
733}