1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11 fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19 fn take_into(
22 &self,
23 array: A,
24 indices: &dyn Array,
25 builder: &mut dyn ArrayBuilder,
26 ) -> VortexResult<()> {
27 builder.extend_from_array(&self.take(array, indices)?)
28 }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33 E: for<'a> TakeFn<&'a E::Array>,
34{
35 fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 let array_ref = array
37 .as_any()
38 .downcast_ref::<E::Array>()
39 .vortex_expect("Failed to downcast array");
40 TakeFn::take(self, array_ref, indices)
41 }
42
43 fn take_into(
44 &self,
45 array: &dyn Array,
46 indices: &dyn Array,
47 builder: &mut dyn ArrayBuilder,
48 ) -> VortexResult<()> {
49 let array_ref = array
50 .as_any()
51 .downcast_ref::<E::Array>()
52 .vortex_expect("Failed to downcast array");
53 TakeFn::take_into(self, array_ref, indices, builder)
54 }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58 if indices.all_invalid()? {
63 return Ok(
64 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65 .into_array(),
66 );
67 }
68
69 if !indices.dtype().is_int() {
70 vortex_bail!(
71 "Take indices must be an integer type, got {}",
72 indices.dtype()
73 );
74 }
75
76 let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80 let taken = take_impl(array, indices)?;
81
82 if let Some(derived_stats) = derived_stats {
83 let mut stats = taken.statistics().to_owned();
84 stats.combine_sets(&derived_stats, array.dtype())?;
85 for (stat, val) in stats.into_iter() {
86 taken.statistics().set(stat, val)
87 }
88 }
89
90 debug_assert_eq!(
91 taken.len(),
92 indices.len(),
93 "Take length mismatch {}",
94 array.encoding()
95 );
96 #[cfg(debug_assertions)]
97 {
98 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
100 assert_eq!(
101 taken.dtype(),
102 &array.dtype().with_nullability(expected_nullability),
103 "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
104 taken.dtype(),
105 indices.dtype().nullability().verbose_display(),
106 array.dtype().nullability().verbose_display(),
107 array.encoding(),
108 );
109 }
110
111 Ok(taken)
112}
113
114pub fn take_into(
115 array: &dyn Array,
116 indices: &dyn Array,
117 builder: &mut dyn ArrayBuilder,
118) -> VortexResult<()> {
119 if indices.all_invalid()? {
120 builder.append_nulls(indices.len());
121 return Ok(());
122 }
123
124 if array.is_empty() && !indices.is_empty() {
125 vortex_bail!("Cannot take_into from an empty array");
126 }
127
128 #[cfg(debug_assertions)]
129 {
130 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
132 assert_eq!(
133 builder.dtype(),
134 &array.dtype().with_nullability(expected_nullability),
135 "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
136 builder.dtype(),
137 indices.dtype().nullability().verbose_display(),
138 array.dtype().nullability().verbose_display(),
139 array.encoding(),
140 );
141 }
142
143 if !indices.dtype().is_int() {
144 vortex_bail!(
145 "Take indices must be an integer type, got {}",
146 indices.dtype()
147 );
148 }
149
150 let before_len = builder.len();
151
152 take_into_impl(array, indices, builder)?;
155
156 let after_len = builder.len();
157
158 debug_assert_eq!(
159 after_len - before_len,
160 indices.len(),
161 "Take_into length mismatch {}",
162 array.encoding()
163 );
164
165 Ok(())
166}
167
168fn derive_take_stats(arr: &dyn Array) -> StatsSet {
169 let stats = arr.statistics().to_owned();
170
171 let is_constant = stats.get_as::<bool>(Stat::IsConstant);
172
173 let mut stats = stats.keep_inexact_stats(&[
174 Stat::Min,
176 Stat::Max,
177 ]);
178
179 if is_constant == Some(Precision::Exact(true)) {
180 stats.set(Stat::IsConstant, Precision::exact(true));
182 }
183
184 stats
185}
186
187fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
188 if let Some(take_from_fn) = indices.vtable().take_from_fn() {
190 if let Some(arr) = take_from_fn.take_from(indices, array)? {
191 return Ok(arr);
192 }
193 }
194
195 if let Some(take_fn) = array.vtable().take_fn() {
198 return take_fn.take(array, indices);
199 }
200
201 log::debug!("No take implementation found for {}", array.encoding());
203 let canonical = array.to_canonical()?.into_array();
204 let vtable = canonical.vtable();
205 let canonical_take_fn = vtable
206 .take_fn()
207 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
208
209 canonical_take_fn.take(&canonical, indices)
210}
211
212fn take_into_impl(
213 array: &dyn Array,
214 indices: &dyn Array,
215 builder: &mut dyn ArrayBuilder,
216) -> VortexResult<()> {
217 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
218 let result_dtype = array.dtype().with_nullability(result_nullability);
219 if &result_dtype != builder.dtype() {
220 vortex_bail!(
221 "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
222 array.encoding(),
223 builder.dtype(),
224 result_dtype,
225 );
226 }
227 if let Some(take_fn) = array.vtable().take_fn() {
228 return take_fn.take_into(array, indices, builder);
229 }
230
231 log::debug!("No take_into implementation found for {}", array.encoding());
233 let canonical = array.to_canonical()?.into_array();
234 let vtable = canonical.vtable();
235 let canonical_take_fn = vtable
236 .take_fn()
237 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
238
239 canonical_take_fn.take_into(&canonical, indices, builder)
240}