vortex_array/compute/
take.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
13use crate::stats::{Precision, Stat, StatsProvider, StatsProviderExt, StatsSet};
14use crate::vtable::VTable;
15use crate::{Array, ArrayRef, Canonical, IntoArray};
16
17static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
18 let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
19 for kernel in inventory::iter::<TakeKernelRef> {
20 compute.register_kernel(kernel.0.clone());
21 }
22 compute
23});
24
25pub(crate) fn warm_up_vtable() -> usize {
26 TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
27}
28
29pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
36 if indices.is_empty() {
37 return Ok(Canonical::empty(
38 &array
39 .dtype()
40 .union_nullability(indices.dtype().nullability()),
41 )
42 .into_array());
43 }
44
45 TAKE_FN
46 .invoke(&InvocationArgs {
47 inputs: &[array.into(), indices.into()],
48 options: &(),
49 })?
50 .unwrap_array()
51}
52
53#[doc(hidden)]
54pub struct Take;
55
56impl ComputeFnVTable for Take {
57 fn invoke(
58 &self,
59 args: &InvocationArgs,
60 kernels: &[ArcRef<dyn Kernel>],
61 ) -> VortexResult<Output> {
62 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
63
64 if indices.all_invalid() {
70 return Ok(ConstantArray::new(
71 Scalar::null(array.dtype().as_nullable()),
72 indices.len(),
73 )
74 .into_array()
75 .into());
76 }
77
78 let taken_array = take_impl(array, indices, kernels)?;
79
80 if !taken_array.is_constant() {
83 propagate_take_stats(array, &taken_array, indices)?;
84 }
85
86 Ok(taken_array.into())
87 }
88
89 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
90 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
91
92 if !indices.dtype().is_int() {
93 vortex_bail!(
94 "Take indices must be an integer type, got {}",
95 indices.dtype()
96 );
97 }
98
99 Ok(array
100 .dtype()
101 .union_nullability(indices.dtype().nullability()))
102 }
103
104 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
105 let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
106 Ok(indices.len())
107 }
108
109 fn is_elementwise(&self) -> bool {
110 false
111 }
112}
113
114fn propagate_take_stats(
115 source: &dyn Array,
116 target: &dyn Array,
117 indices: &dyn Array,
118) -> VortexResult<()> {
119 target.statistics().with_mut_typed_stats_set(|mut st| {
120 if indices.all_valid() {
121 let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
122 if is_constant == Some(Precision::Exact(true)) {
123 st.set(Stat::IsConstant, Precision::exact(true));
125 }
126 }
127 let inexact_min_max = [Stat::Min, Stat::Max]
128 .into_iter()
129 .filter_map(|stat| {
130 source
131 .statistics()
132 .get(stat)
133 .map(|v| (stat, v.map(|s| s.into_value()).into_inexact()))
134 })
135 .collect::<Vec<_>>();
136 st.combine_sets(
137 &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
138 )
139 })
140}
141
142fn take_impl(
143 array: &dyn Array,
144 indices: &dyn Array,
145 kernels: &[ArcRef<dyn Kernel>],
146) -> VortexResult<ArrayRef> {
147 let args = InvocationArgs {
148 inputs: &[array.into(), indices.into()],
149 options: &(),
150 };
151
152 for kernel in TAKE_FROM_FN.kernels() {
154 if let Some(output) = kernel.invoke(&args)? {
155 return output.unwrap_array();
156 }
157 }
158 if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
159 return output.unwrap_array();
160 }
161
162 for kernel in kernels {
164 if let Some(output) = kernel.invoke(&args)? {
165 return output.unwrap_array();
166 }
167 }
168 if let Some(output) = array.invoke(&TAKE_FN, &args)? {
169 return output.unwrap_array();
170 }
171
172 if !array.is_canonical() {
174 log::debug!("No take implementation found for {}", array.encoding_id());
175 let canonical = array.to_canonical();
176 return take(canonical.as_ref(), indices);
177 }
178
179 vortex_bail!("No take implementation found for {}", array.encoding_id());
180}
181
182struct TakeArgs<'a> {
183 array: &'a dyn Array,
184 indices: &'a dyn Array,
185}
186
187impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
188 type Error = VortexError;
189
190 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
191 if value.inputs.len() != 2 {
192 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
193 }
194 let array = value.inputs[0]
195 .array()
196 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
197 let indices = value.inputs[1]
198 .array()
199 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
200 Ok(Self { array, indices })
201 }
202}
203
204pub trait TakeKernel: VTable {
205 fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
212}
213
214pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
216inventory::collect!(TakeKernelRef);
217
218#[derive(Debug)]
219pub struct TakeKernelAdapter<V: VTable>(pub V);
220
221impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
222 pub const fn lift(&'static self) -> TakeKernelRef {
223 TakeKernelRef(ArcRef::new_ref(self))
224 }
225}
226
227impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
228 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
229 let inputs = TakeArgs::try_from(args)?;
230 let Some(array) = inputs.array.as_opt::<V>() else {
231 return Ok(None);
232 };
233 Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
234 }
235}
236
237static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
238 let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
239 for kernel in inventory::iter::<TakeFromKernelRef> {
240 compute.register_kernel(kernel.0.clone());
241 }
242 compute
243});
244
245pub struct TakeFrom;
246
247impl ComputeFnVTable for TakeFrom {
248 fn invoke(
249 &self,
250 _args: &InvocationArgs,
251 _kernels: &[ArcRef<dyn Kernel>],
252 ) -> VortexResult<Output> {
253 vortex_bail!(
254 "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
255 )
256 }
257
258 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
259 Take.return_dtype(args)
260 }
261
262 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
263 Take.return_len(args)
264 }
265
266 fn is_elementwise(&self) -> bool {
267 Take.is_elementwise()
268 }
269}
270
271pub trait TakeFromKernel: VTable {
272 fn take_from(&self, indices: &Self::Array, array: &dyn Array)
275 -> VortexResult<Option<ArrayRef>>;
276}
277
278pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
279inventory::collect!(TakeFromKernelRef);
280
281#[derive(Debug)]
282pub struct TakeFromKernelAdapter<V: VTable>(pub V);
283
284impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
285 pub const fn lift(&'static self) -> TakeFromKernelRef {
286 TakeFromKernelRef(ArcRef::new_ref(self))
287 }
288}
289
290impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
291 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
292 let inputs = TakeArgs::try_from(args)?;
293 let Some(indices) = inputs.indices.as_opt::<V>() else {
294 return Ok(None);
295 };
296 Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
297 }
298}