vortex_array/compute/
take.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};

use crate::encoding::Encoding;
use crate::stats::{ArrayStatistics, Stat};
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical};

/// Compute function that builds a new array from the values of an input array at given positions.
pub trait TakeFn<Array> {
    /// Build a new array out of the values of `array` located at the positions listed in
    /// `indices`.
    ///
    /// # Panics
    ///
    /// Passing `indices` that are not valid positions in `array` causes a panic.
    fn take(&self, array: &Array, indices: &ArrayData) -> VortexResult<ArrayData>;

    /// Variant of [`TakeFn::take`] that does not bounds-check `indices`.
    ///
    /// Unless an implementation overrides it, this forwards to the checked [`TakeFn::take`].
    ///
    /// # Safety
    ///
    /// No bounds checking is performed on `indices`, so the caller must guarantee that every
    /// index is valid for the provided `array`. Breaking this guarantee can lead to
    /// out-of-bounds memory access or other undefined behavior.
    unsafe fn take_unchecked(&self, array: &Array, indices: &ArrayData) -> VortexResult<ArrayData> {
        <Self as TakeFn<Array>>::take(self, array, indices)
    }
}

/// Lets an encoding's [`TakeFn`] for its concrete array type be invoked on an untyped
/// [`ArrayData`] by downcasting `array` to `E`'s array type first.
///
/// Both the checked and the unchecked entry points are forwarded. Without forwarding
/// `take_unchecked`, the trait default would route unchecked calls back through the checked
/// `take`, and an encoding's own unchecked fast path would never be reached.
impl<E: Encoding> TakeFn<ArrayData> for E
where
    E: TakeFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    fn take(&self, array: &ArrayData, indices: &ArrayData) -> VortexResult<ArrayData> {
        let (array_ref, encoding) = array.try_downcast_ref::<E>()?;
        TakeFn::take(encoding, array_ref, indices)
    }

    unsafe fn take_unchecked(
        &self,
        array: &ArrayData,
        indices: &ArrayData,
    ) -> VortexResult<ArrayData> {
        let (array_ref, encoding) = array.try_downcast_ref::<E>()?;
        // SAFETY: our caller guarantees every index is valid for `array`; `array_ref` is that
        // same array viewed through its concrete encoding type, so the guarantee carries over.
        unsafe { TakeFn::take_unchecked(encoding, array_ref, indices) }
    }
}

/// Create a new array containing the values of `array` at the positions given by `indices`.
///
/// `indices` must be a non-nullable integer array. The result is expected to have the same
/// length as `indices` and the same dtype as `array` (both are checked with debug assertions).
///
/// # Errors
///
/// Returns an error if `indices` is not a non-nullable integer array, or if an available
/// `Max` statistic on `indices` shows an index out of bounds for `array`. Errors from the
/// underlying take implementation are also propagated.
pub fn take(
    array: impl AsRef<ArrayData>,
    indices: impl AsRef<ArrayData>,
) -> VortexResult<ArrayData> {
    // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
    //  the filter function since they're typically optimised for this case.
    // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
    //  such that canonicalize does less work.

    let array = array.as_ref();
    let indices = indices.as_ref();

    if !indices.dtype().is_int() || indices.dtype().is_nullable() {
        vortex_bail!(
            "Take indices must be a non-nullable integer type, got {}",
            indices.dtype()
        );
    }

    // When the maximum index is available from statistics it settles bounds checking up front:
    // either some index is out of bounds (report it as an error rather than letting the take
    // implementation panic), or every index is in bounds and the unchecked path can be used.
    let max_index = indices.statistics().get_as::<usize>(Stat::Max);
    if let Some(max) = max_index.filter(|&max| max >= array.len()) {
        vortex_bail!(
            "Take index {} is out of bounds for array of length {}",
            max,
            array.len()
        );
    }
    let indices_in_bounds = max_index.is_some();

    let taken = take_impl(array, indices, indices_in_bounds)?;

    debug_assert_eq!(
        taken.len(),
        indices.len(),
        "Take length mismatch {}",
        array.encoding().id()
    );
    debug_assert_eq!(
        array.dtype(),
        taken.dtype(),
        "Take dtype mismatch {}",
        array.encoding().id()
    );

    Ok(taken)
}

fn take_impl(
    array: &ArrayData,
    indices: &ArrayData,
    checked_indices: bool,
) -> VortexResult<ArrayData> {
    // If TakeFn defined for the encoding, delegate to TakeFn.
    // If we know from stats that indices are all valid, we can avoid all bounds checks.
    if let Some(take_fn) = array.encoding().take_fn() {
        let result = if checked_indices {
            // SAFETY: indices are all inbounds per stats.
            // TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad.
            unsafe { take_fn.take_unchecked(array, indices) }
        } else {
            take_fn.take(array, indices)
        }?;
        if array.dtype() != result.dtype() {
            vortex_bail!(
                "TakeFn {} changed array dtype from {} to {}",
                array.encoding().id(),
                array.dtype(),
                result.dtype()
            );
        }
        return Ok(result);
    }

    // Otherwise, flatten and try again.
    log::debug!("No take implementation found for {}", array.encoding().id());
    let canonical = array.clone().into_canonical()?.into_array();
    let canonical_take_fn = canonical
        .encoding()
        .take_fn()
        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding().id()))?;

    if checked_indices {
        // SAFETY: indices are known to be in-bound from stats
        unsafe { canonical_take_fn.take_unchecked(&canonical, indices) }
    } else {
        canonical_take_fn.take(&canonical, indices)
    }
}