1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11    fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19    fn take_into(
22        &self,
23        array: A,
24        indices: &dyn Array,
25        builder: &mut dyn ArrayBuilder,
26    ) -> VortexResult<()> {
27        builder.extend_from_array(&self.take(array, indices)?)
28    }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33    E: for<'a> TakeFn<&'a E::Array>,
34{
35    fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36        let array_ref = array
37            .as_any()
38            .downcast_ref::<E::Array>()
39            .vortex_expect("Failed to downcast array");
40        TakeFn::take(self, array_ref, indices)
41    }
42
43    fn take_into(
44        &self,
45        array: &dyn Array,
46        indices: &dyn Array,
47        builder: &mut dyn ArrayBuilder,
48    ) -> VortexResult<()> {
49        let array_ref = array
50            .as_any()
51            .downcast_ref::<E::Array>()
52            .vortex_expect("Failed to downcast array");
53        TakeFn::take_into(self, array_ref, indices, builder)
54    }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58    if indices.all_invalid()? {
63        return Ok(
64            ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65                .into_array(),
66        );
67    }
68
69    if !indices.dtype().is_int() {
70        vortex_bail!(
71            "Take indices must be an integer type, got {}",
72            indices.dtype()
73        );
74    }
75
76    let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80    let taken = take_impl(array, indices)?;
81
82    if let Some(derived_stats) = derived_stats {
83        let mut stats = taken.statistics().to_owned();
84        stats.combine_sets(&derived_stats, array.dtype())?;
85        for (stat, val) in stats.into_iter() {
86            taken.statistics().set(stat, val)
87        }
88    }
89
90    assert_eq!(
91        taken.len(),
92        indices.len(),
93        "Take length mismatch {}",
94        array.encoding()
95    );
96    let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
98    assert_eq!(
99        taken.dtype(),
100        &array.dtype().with_nullability(expected_nullability),
101        "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
102        taken.dtype(),
103        indices.dtype().nullability().verbose_display(),
104        array.dtype().nullability().verbose_display(),
105        array.encoding(),
106    );
107
108    Ok(taken)
109}
110
111pub fn take_into(
112    array: &dyn Array,
113    indices: &dyn Array,
114    builder: &mut dyn ArrayBuilder,
115) -> VortexResult<()> {
116    if indices.all_invalid()? {
117        builder.append_nulls(indices.len());
118        return Ok(());
119    }
120
121    if array.is_empty() && !indices.is_empty() {
122        vortex_bail!("Cannot take_into from an empty array");
123    }
124
125    let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
127    assert_eq!(
128        builder.dtype(),
129        &array.dtype().with_nullability(expected_nullability),
130        "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
131        builder.dtype(),
132        indices.dtype().nullability().verbose_display(),
133        array.dtype().nullability().verbose_display(),
134        array.encoding(),
135    );
136
137    if !indices.dtype().is_int() {
138        vortex_bail!(
139            "Take indices must be an integer type, got {}",
140            indices.dtype()
141        );
142    }
143
144    let before_len = builder.len();
145
146    take_into_impl(array, indices, builder)?;
149
150    let after_len = builder.len();
151
152    assert_eq!(
153        after_len - before_len,
154        indices.len(),
155        "Take_into length mismatch {}",
156        array.encoding()
157    );
158
159    Ok(())
160}
161
162fn derive_take_stats(arr: &dyn Array) -> StatsSet {
163    let stats = arr.statistics().to_owned();
164
165    let is_constant = stats.get_as::<bool>(Stat::IsConstant);
166
167    let mut stats = stats.keep_inexact_stats(&[
168        Stat::Min,
170        Stat::Max,
171    ]);
172
173    if is_constant == Some(Precision::Exact(true)) {
174        stats.set(Stat::IsConstant, Precision::exact(true));
176    }
177
178    stats
179}
180
181fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
182    if let Some(take_from_fn) = indices.vtable().take_from_fn() {
184        if let Some(arr) = take_from_fn.take_from(indices, array)? {
185            return Ok(arr);
186        }
187    }
188
189    if let Some(take_fn) = array.vtable().take_fn() {
192        return take_fn.take(array, indices);
193    }
194
195    log::debug!("No take implementation found for {}", array.encoding());
197    let canonical = array.to_canonical()?.into_array();
198    let vtable = canonical.vtable();
199    let canonical_take_fn = vtable
200        .take_fn()
201        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
202
203    canonical_take_fn.take(&canonical, indices)
204}
205
206fn take_into_impl(
207    array: &dyn Array,
208    indices: &dyn Array,
209    builder: &mut dyn ArrayBuilder,
210) -> VortexResult<()> {
211    let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
212    let result_dtype = array.dtype().with_nullability(result_nullability);
213    if &result_dtype != builder.dtype() {
214        vortex_bail!(
215            "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
216            array.encoding(),
217            builder.dtype(),
218            result_dtype,
219        );
220    }
221    if let Some(take_fn) = array.vtable().take_fn() {
222        return take_fn.take_into(array, indices, builder);
223    }
224
225    log::debug!("No take_into implementation found for {}", array.encoding());
227    let canonical = array.to_canonical()?.into_array();
228    let vtable = canonical.vtable();
229    let canonical_take_fn = vtable
230        .take_fn()
231        .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
232
233    canonical_take_fn.take_into(&canonical, indices, builder)
234}