// vortex_array/compute/take.rs
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use crate::encoding::Encoding;
use crate::stats::Stat;
use crate::{Array, IntoArray, IntoCanonical};
/// Encoding-level implementation of the `take` compute function.
///
/// `A` is the array type the implementation operates on.
pub trait TakeFn<A> {
    /// Build a new array holding the values of `array` at each position in `indices`.
    ///
    /// # Panics
    ///
    /// Passing any index that is invalid for `array` causes a panic.
    fn take(&self, array: &A, indices: &Array) -> VortexResult<Array>;

    /// Build a new array holding the values of `array` at each position in `indices`,
    /// skipping bounds checks on the indices.
    ///
    /// The default implementation simply defers to the checked [`TakeFn::take`];
    /// implementors may override it with a faster path.
    ///
    /// # Safety
    ///
    /// No bounds checking is performed on `indices`. The caller must guarantee that every
    /// index is valid for `array`; violating this may read out of bounds or otherwise
    /// cause undefined behavior.
    unsafe fn take_unchecked(&self, array: &A, indices: &Array) -> VortexResult<Array> {
        <Self as TakeFn<A>>::take(self, array, indices)
    }
}
/// Bridge from the type-erased [`Array`] to an encoding's concrete array type.
///
/// Any encoding `E` that implements `TakeFn` for its own `E::Array` automatically implements
/// `TakeFn<Array>`. The erased array is downcast first, then the call is forwarded.
impl<E: Encoding> TakeFn<Array> for E
where
    E: TakeFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a Array, Error = VortexError>,
{
    fn take(&self, array: &Array, indices: &Array) -> VortexResult<Array> {
        let (array_ref, encoding) = array.try_downcast_ref::<E>()?;
        TakeFn::take(encoding, array_ref, indices)
    }

    /// Forwards to the encoding's own `take_unchecked`.
    ///
    /// Without this override the trait default would run the bounds-checked `take`. Callers
    /// that have already proven the indices in bounds would then never reach the encoding's
    /// unchecked fast path.
    ///
    /// # Safety
    ///
    /// Same contract as [`TakeFn::take_unchecked`]: every index must be valid for `array`.
    unsafe fn take_unchecked(&self, array: &Array, indices: &Array) -> VortexResult<Array> {
        let (array_ref, encoding) = array.try_downcast_ref::<E>()?;
        // SAFETY: the caller guarantees every index is in bounds for `array`. `array_ref` is the
        // same array viewed through its concrete type, so the contract carries over unchanged.
        unsafe { TakeFn::take_unchecked(encoding, array_ref, indices) }
    }
}
/// Gather the elements of `array` at the positions listed in `indices`.
///
/// `indices` must be a non-nullable integer array. In debug builds the result is asserted to
/// have the same length as `indices` and the same dtype as `array`.
///
/// # Errors
///
/// Fails if `indices` is not a non-nullable integer array, or if the underlying take
/// implementation (or the canonicalization fallback) fails.
pub fn take(array: impl AsRef<Array>, indices: impl AsRef<Array>) -> VortexResult<Array> {
    // TODO(ngates): strictly sorted (sorted + unique) indices could be delegated to the filter
    //  function, which is typically optimised for that case.
    // TODO(ngates): when the minimum index is large, slicing `array` and shifting the indices
    //  down would let canonicalization do less work.
    let (array, indices) = (array.as_ref(), indices.as_ref());

    let indices_dtype = indices.dtype();
    let valid_indices_dtype = indices_dtype.is_int() && !indices_dtype.is_nullable();
    if !valid_indices_dtype {
        vortex_bail!(
            "Take indices must be a non-nullable integer type, got {}",
            indices_dtype
        );
    }

    // If the statistics already prove that every index is below `array.len()`, the take can
    // skip per-element bounds checks.
    let max_index = indices.statistics().get_as::<usize>(Stat::Max);
    let indices_in_bounds = matches!(max_index, Some(max) if max < array.len());

    let taken = take_impl(array, indices, indices_in_bounds)?;

    debug_assert_eq!(
        taken.len(),
        indices.len(),
        "Take length mismatch {}",
        array.encoding()
    );
    debug_assert_eq!(
        array.dtype(),
        taken.dtype(),
        "Take dtype mismatch {}",
        array.encoding()
    );

    Ok(taken)
}
fn take_impl(array: &Array, indices: &Array, checked_indices: bool) -> VortexResult<Array> {
// If TakeFn defined for the encoding, delegate to TakeFn.
// If we know from stats that indices are all valid, we can avoid all bounds checks.
if let Some(take_fn) = array.vtable().take_fn() {
let result = if checked_indices {
// SAFETY: indices are all inbounds per stats.
// TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad.
unsafe { take_fn.take_unchecked(array, indices) }
} else {
take_fn.take(array, indices)
}?;
if array.dtype() != result.dtype() {
vortex_bail!(
"TakeFn {} changed array dtype from {} to {}",
array.encoding(),
array.dtype(),
result.dtype()
);
}
return Ok(result);
}
// Otherwise, flatten and try again.
log::debug!("No take implementation found for {}", array.encoding());
let canonical = array.clone().into_canonical()?.into_array();
let canonical_take_fn = canonical
.vtable()
.take_fn()
.ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
if checked_indices {
// SAFETY: indices are known to be in-bound from stats
unsafe { canonical_take_fn.take_unchecked(&canonical, indices) }
} else {
canonical_take_fn.take(&canonical, indices)
}
}