1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11 fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19 fn take_into(
22 &self,
23 array: A,
24 indices: &dyn Array,
25 builder: &mut dyn ArrayBuilder,
26 ) -> VortexResult<()> {
27 builder.extend_from_array(&self.take(array, indices)?)
28 }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33 E: for<'a> TakeFn<&'a E::Array>,
34{
35 fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 let array_ref = array
37 .as_any()
38 .downcast_ref::<E::Array>()
39 .vortex_expect("Failed to downcast array");
40 TakeFn::take(self, array_ref, indices)
41 }
42
43 fn take_into(
44 &self,
45 array: &dyn Array,
46 indices: &dyn Array,
47 builder: &mut dyn ArrayBuilder,
48 ) -> VortexResult<()> {
49 let array_ref = array
50 .as_any()
51 .downcast_ref::<E::Array>()
52 .vortex_expect("Failed to downcast array");
53 TakeFn::take_into(self, array_ref, indices, builder)
54 }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58 if indices.all_invalid()? {
63 return Ok(
64 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65 .into_array(),
66 );
67 }
68
69 if !indices.dtype().is_int() {
70 vortex_bail!(
71 "Take indices must be an integer type, got {}",
72 indices.dtype()
73 );
74 }
75
76 let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80 let taken = take_impl(array, indices)?;
81
82 if let Some(derived_stats) = derived_stats {
83 let mut stats = taken.statistics().to_owned();
84 stats.combine_sets(&derived_stats, array.dtype())?;
85 for (stat, val) in stats.into_iter() {
86 taken.statistics().set(stat, val)
87 }
88 }
89
90 debug_assert_eq!(
91 taken.len(),
92 indices.len(),
93 "Take length mismatch {}",
94 array.encoding()
95 );
96 #[cfg(debug_assertions)]
97 {
98 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
100 assert_eq!(
101 taken.dtype(),
102 &array.dtype().with_nullability(expected_nullability),
103 "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
104 taken.dtype(),
105 indices.dtype().nullability().verbose_display(),
106 array.dtype().nullability().verbose_display(),
107 array.encoding(),
108 );
109 }
110
111 Ok(taken)
112}
113
114pub fn take_into(
115 array: &dyn Array,
116 indices: &dyn Array,
117 builder: &mut dyn ArrayBuilder,
118) -> VortexResult<()> {
119 if array.is_empty() && !indices.is_empty() {
120 vortex_bail!("Cannot take_into from an empty array");
121 }
122
123 #[cfg(debug_assertions)]
124 {
125 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
127 assert_eq!(
128 builder.dtype(),
129 &array.dtype().with_nullability(expected_nullability),
130 "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
131 builder.dtype(),
132 indices.dtype().nullability().verbose_display(),
133 array.dtype().nullability().verbose_display(),
134 array.encoding(),
135 );
136 }
137
138 if !indices.dtype().is_int() {
139 vortex_bail!(
140 "Take indices must be an integer type, got {}",
141 indices.dtype()
142 );
143 }
144
145 let before_len = builder.len();
146
147 take_into_impl(array, indices, builder)?;
150
151 let after_len = builder.len();
152
153 debug_assert_eq!(
154 after_len - before_len,
155 indices.len(),
156 "Take_into length mismatch {}",
157 array.encoding()
158 );
159
160 Ok(())
161}
162
163fn derive_take_stats(arr: &dyn Array) -> StatsSet {
164 let stats = arr.statistics().to_owned();
165
166 let is_constant = stats.get_as::<bool>(Stat::IsConstant);
167
168 let mut stats = stats.keep_inexact_stats(&[
169 Stat::Min,
171 Stat::Max,
172 ]);
173
174 if is_constant == Some(Precision::Exact(true)) {
175 stats.set(Stat::IsConstant, Precision::exact(true));
177 }
178
179 stats
180}
181
182fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
183 if let Some(take_from_fn) = indices.vtable().take_from_fn() {
185 if let Some(arr) = take_from_fn.take_from(indices, array)? {
186 return Ok(arr);
187 }
188 }
189
190 if let Some(take_fn) = array.vtable().take_fn() {
193 return take_fn.take(array, indices);
194 }
195
196 log::debug!("No take implementation found for {}", array.encoding());
198 let canonical = array.to_canonical()?.into_array();
199 let vtable = canonical.vtable();
200 let canonical_take_fn = vtable
201 .take_fn()
202 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
203
204 canonical_take_fn.take(&canonical, indices)
205}
206
207fn take_into_impl(
208 array: &dyn Array,
209 indices: &dyn Array,
210 builder: &mut dyn ArrayBuilder,
211) -> VortexResult<()> {
212 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
213 let result_dtype = array.dtype().with_nullability(result_nullability);
214 if &result_dtype != builder.dtype() {
215 vortex_bail!(
216 "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
217 array.encoding(),
218 builder.dtype(),
219 result_dtype,
220 );
221 }
222 if let Some(take_fn) = array.vtable().take_fn() {
223 return take_fn.take_into(array, indices, builder);
224 }
225
226 log::debug!("No take_into implementation found for {}", array.encoding());
228 let canonical = array.to_canonical()?.into_array();
229 let vtable = canonical.vtable();
230 let canonical_take_fn = vtable
231 .take_fn()
232 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
233
234 canonical_take_fn.take_into(&canonical, indices, builder)
235}