1use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
2use vortex_scalar::Scalar;
3
4use crate::arrays::ConstantArray;
5use crate::builders::ArrayBuilder;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProviderExt, StatsSet};
8use crate::{Array, ArrayRef, IntoArray};
9
10pub trait TakeFn<A> {
11 fn take(&self, array: A, indices: &dyn Array) -> VortexResult<ArrayRef>;
18
19 fn take_into(
22 &self,
23 array: A,
24 indices: &dyn Array,
25 builder: &mut dyn ArrayBuilder,
26 ) -> VortexResult<()> {
27 builder.extend_from_array(&self.take(array, indices)?)
28 }
29}
30
31impl<E: Encoding> TakeFn<&dyn Array> for E
32where
33 E: for<'a> TakeFn<&'a E::Array>,
34{
35 fn take(&self, array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 let array_ref = array
37 .as_any()
38 .downcast_ref::<E::Array>()
39 .vortex_expect("Failed to downcast array");
40 TakeFn::take(self, array_ref, indices)
41 }
42
43 fn take_into(
44 &self,
45 array: &dyn Array,
46 indices: &dyn Array,
47 builder: &mut dyn ArrayBuilder,
48 ) -> VortexResult<()> {
49 let array_ref = array
50 .as_any()
51 .downcast_ref::<E::Array>()
52 .vortex_expect("Failed to downcast array");
53 TakeFn::take_into(self, array_ref, indices, builder)
54 }
55}
56
57pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
58 if indices.all_invalid()? {
63 return Ok(
64 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
65 .into_array(),
66 );
67 }
68
69 if !indices.dtype().is_int() {
70 vortex_bail!(
71 "Take indices must be an integer type, got {}",
72 indices.dtype()
73 );
74 }
75
76 let derived_stats = (!array.is_constant()).then(|| derive_take_stats(array));
79
80 let taken = take_impl(array, indices)?;
81
82 if let Some(derived_stats) = derived_stats {
83 let mut stats = taken.statistics().to_owned();
84 stats.combine_sets(&derived_stats, array.dtype())?;
85 for (stat, val) in stats.into_iter() {
86 taken.statistics().set(stat, val)
87 }
88 }
89
90 assert_eq!(
91 taken.len(),
92 indices.len(),
93 "Take length mismatch {}",
94 array.encoding()
95 );
96 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
98 assert_eq!(
99 taken.dtype(),
100 &array.dtype().with_nullability(expected_nullability),
101 "Take result ({}) should be nullable if either the indices ({}) or the array ({}) are nullable. ({})",
102 taken.dtype(),
103 indices.dtype().nullability().verbose_display(),
104 array.dtype().nullability().verbose_display(),
105 array.encoding(),
106 );
107
108 Ok(taken)
109}
110
111pub fn take_into(
112 array: &dyn Array,
113 indices: &dyn Array,
114 builder: &mut dyn ArrayBuilder,
115) -> VortexResult<()> {
116 if indices.all_invalid()? {
117 builder.append_nulls(indices.len());
118 return Ok(());
119 }
120
121 if array.is_empty() && !indices.is_empty() {
122 vortex_bail!("Cannot take_into from an empty array");
123 }
124
125 let expected_nullability = indices.dtype().nullability() | array.dtype().nullability();
127 assert_eq!(
128 builder.dtype(),
129 &array.dtype().with_nullability(expected_nullability),
130 "Take_into result ({}) should be nullable if, and only if, either the indices ({}) or the array ({}) are nullable. ({})",
131 builder.dtype(),
132 indices.dtype().nullability().verbose_display(),
133 array.dtype().nullability().verbose_display(),
134 array.encoding(),
135 );
136
137 if !indices.dtype().is_int() {
138 vortex_bail!(
139 "Take indices must be an integer type, got {}",
140 indices.dtype()
141 );
142 }
143
144 let before_len = builder.len();
145
146 take_into_impl(array, indices, builder)?;
149
150 let after_len = builder.len();
151
152 assert_eq!(
153 after_len - before_len,
154 indices.len(),
155 "Take_into length mismatch {}",
156 array.encoding()
157 );
158
159 Ok(())
160}
161
162fn derive_take_stats(arr: &dyn Array) -> StatsSet {
163 let stats = arr.statistics().to_owned();
164
165 let is_constant = stats.get_as::<bool>(Stat::IsConstant);
166
167 let mut stats = stats.keep_inexact_stats(&[
168 Stat::Min,
170 Stat::Max,
171 ]);
172
173 if is_constant == Some(Precision::Exact(true)) {
174 stats.set(Stat::IsConstant, Precision::exact(true));
176 }
177
178 stats
179}
180
181fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
182 if let Some(take_from_fn) = indices.vtable().take_from_fn() {
184 if let Some(arr) = take_from_fn.take_from(indices, array)? {
185 return Ok(arr);
186 }
187 }
188
189 if let Some(take_fn) = array.vtable().take_fn() {
192 return take_fn.take(array, indices);
193 }
194
195 log::debug!("No take implementation found for {}", array.encoding());
197 let canonical = array.to_canonical()?.into_array();
198 let vtable = canonical.vtable();
199 let canonical_take_fn = vtable
200 .take_fn()
201 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
202
203 canonical_take_fn.take(&canonical, indices)
204}
205
206fn take_into_impl(
207 array: &dyn Array,
208 indices: &dyn Array,
209 builder: &mut dyn ArrayBuilder,
210) -> VortexResult<()> {
211 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
212 let result_dtype = array.dtype().with_nullability(result_nullability);
213 if &result_dtype != builder.dtype() {
214 vortex_bail!(
215 "TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
216 array.encoding(),
217 builder.dtype(),
218 result_dtype,
219 );
220 }
221 if let Some(take_fn) = array.vtable().take_fn() {
222 return take_fn.take_into(array, indices, builder);
223 }
224
225 log::debug!("No take_into implementation found for {}", array.encoding());
227 let canonical = array.to_canonical()?.into_array();
228 let vtable = canonical.vtable();
229 let canonical_take_fn = vtable
230 .take_fn()
231 .ok_or_else(|| vortex_err!(NotImplemented: "take", canonical.encoding()))?;
232
233 canonical_take_fn.take_into(&canonical, indices, builder)
234}