vortex_array/compute/
take.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexError;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_scalar::Scalar;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::IntoArray;
18use crate::arrays::ConstantArray;
19use crate::compute::ComputeFn;
20use crate::compute::ComputeFnVTable;
21use crate::compute::InvocationArgs;
22use crate::compute::Kernel;
23use crate::compute::Output;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::expr::stats::StatsProviderExt;
28use crate::stats::StatsSet;
29use crate::vtable::VTable;
30
31static TAKE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
32 let compute = ComputeFn::new("take".into(), ArcRef::new_ref(&Take));
33 for kernel in inventory::iter::<TakeKernelRef> {
34 compute.register_kernel(kernel.0.clone());
35 }
36 compute
37});
38
39pub(crate) fn warm_up_vtable() -> usize {
40 TAKE_FN.kernels().len() + TAKE_FROM_FN.kernels().len()
41}
42
43pub fn take(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
50 if indices.is_empty() {
51 return Ok(Canonical::empty(
52 &array
53 .dtype()
54 .union_nullability(indices.dtype().nullability()),
55 )
56 .into_array());
57 }
58
59 TAKE_FN
60 .invoke(&InvocationArgs {
61 inputs: &[array.into(), indices.into()],
62 options: &(),
63 })?
64 .unwrap_array()
65}
66
67#[doc(hidden)]
68pub struct Take;
69
70impl ComputeFnVTable for Take {
71 fn invoke(
72 &self,
73 args: &InvocationArgs,
74 kernels: &[ArcRef<dyn Kernel>],
75 ) -> VortexResult<Output> {
76 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
77
78 if indices.all_invalid() {
84 return Ok(ConstantArray::new(
85 Scalar::null(array.dtype().as_nullable()),
86 indices.len(),
87 )
88 .into_array()
89 .into());
90 }
91
92 let taken_array = take_impl(array, indices, kernels)?;
93
94 if !taken_array.is_constant() {
97 propagate_take_stats(array, &taken_array, indices)?;
98 }
99
100 Ok(taken_array.into())
101 }
102
103 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
104 let TakeArgs { array, indices } = TakeArgs::try_from(args)?;
105
106 if !indices.dtype().is_int() {
107 vortex_bail!(
108 "Take indices must be an integer type, got {}",
109 indices.dtype()
110 );
111 }
112
113 Ok(array
114 .dtype()
115 .union_nullability(indices.dtype().nullability()))
116 }
117
118 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
119 let TakeArgs { indices, .. } = TakeArgs::try_from(args)?;
120 Ok(indices.len())
121 }
122
123 fn is_elementwise(&self) -> bool {
124 false
125 }
126}
127
128fn propagate_take_stats(
129 source: &dyn Array,
130 target: &dyn Array,
131 indices: &dyn Array,
132) -> VortexResult<()> {
133 target.statistics().with_mut_typed_stats_set(|mut st| {
134 if indices.all_valid() {
135 let is_constant = source.statistics().get_as::<bool>(Stat::IsConstant);
136 if is_constant == Some(Precision::Exact(true)) {
137 st.set(Stat::IsConstant, Precision::exact(true));
139 }
140 }
141 let inexact_min_max = [Stat::Min, Stat::Max]
142 .into_iter()
143 .filter_map(|stat| {
144 source
145 .statistics()
146 .get(stat)
147 .map(|v| (stat, v.map(|s| s.into_value()).into_inexact()))
148 })
149 .collect::<Vec<_>>();
150 st.combine_sets(
151 &(unsafe { StatsSet::new_unchecked(inexact_min_max) }).as_typed_ref(source.dtype()),
152 )
153 })
154}
155
156fn take_impl(
157 array: &dyn Array,
158 indices: &dyn Array,
159 kernels: &[ArcRef<dyn Kernel>],
160) -> VortexResult<ArrayRef> {
161 let args = InvocationArgs {
162 inputs: &[array.into(), indices.into()],
163 options: &(),
164 };
165
166 for kernel in TAKE_FROM_FN.kernels() {
168 if let Some(output) = kernel.invoke(&args)? {
169 return output.unwrap_array();
170 }
171 }
172 if let Some(output) = indices.invoke(&TAKE_FROM_FN, &args)? {
173 return output.unwrap_array();
174 }
175
176 for kernel in kernels {
178 if let Some(output) = kernel.invoke(&args)? {
179 return output.unwrap_array();
180 }
181 }
182 if let Some(output) = array.invoke(&TAKE_FN, &args)? {
183 return output.unwrap_array();
184 }
185
186 if !array.is_canonical() {
188 log::debug!("No take implementation found for {}", array.encoding_id());
189 let canonical = array.to_canonical();
190 return take(canonical.as_ref(), indices);
191 }
192
193 vortex_bail!("No take implementation found for {}", array.encoding_id());
194}
195
196struct TakeArgs<'a> {
197 array: &'a dyn Array,
198 indices: &'a dyn Array,
199}
200
201impl<'a> TryFrom<&InvocationArgs<'a>> for TakeArgs<'a> {
202 type Error = VortexError;
203
204 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
205 if value.inputs.len() != 2 {
206 vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
207 }
208 let array = value.inputs[0]
209 .array()
210 .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
211 let indices = value.inputs[1]
212 .array()
213 .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
214 Ok(Self { array, indices })
215 }
216}
217
218pub trait TakeKernel: VTable {
219 fn take(&self, array: &Self::Array, indices: &dyn Array) -> VortexResult<ArrayRef>;
226}
227
228pub struct TakeKernelRef(pub ArcRef<dyn Kernel>);
230inventory::collect!(TakeKernelRef);
231
232#[derive(Debug)]
233pub struct TakeKernelAdapter<V: VTable>(pub V);
234
235impl<V: VTable + TakeKernel> TakeKernelAdapter<V> {
236 pub const fn lift(&'static self) -> TakeKernelRef {
237 TakeKernelRef(ArcRef::new_ref(self))
238 }
239}
240
241impl<V: VTable + TakeKernel> Kernel for TakeKernelAdapter<V> {
242 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
243 let inputs = TakeArgs::try_from(args)?;
244 let Some(array) = inputs.array.as_opt::<V>() else {
245 return Ok(None);
246 };
247 Ok(Some(V::take(&self.0, array, inputs.indices)?.into()))
248 }
249}
250
251static TAKE_FROM_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
252 let compute = ComputeFn::new("take_from".into(), ArcRef::new_ref(&TakeFrom));
253 for kernel in inventory::iter::<TakeFromKernelRef> {
254 compute.register_kernel(kernel.0.clone());
255 }
256 compute
257});
258
259pub struct TakeFrom;
260
261impl ComputeFnVTable for TakeFrom {
262 fn invoke(
263 &self,
264 _args: &InvocationArgs,
265 _kernels: &[ArcRef<dyn Kernel>],
266 ) -> VortexResult<Output> {
267 vortex_bail!(
268 "TakeFrom should not be invoked directly. Its kernels are used to accelerated the Take function"
269 )
270 }
271
272 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
273 Take.return_dtype(args)
274 }
275
276 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
277 Take.return_len(args)
278 }
279
280 fn is_elementwise(&self) -> bool {
281 Take.is_elementwise()
282 }
283}
284
285pub trait TakeFromKernel: VTable {
286 fn take_from(&self, indices: &Self::Array, array: &dyn Array)
289 -> VortexResult<Option<ArrayRef>>;
290}
291
292pub struct TakeFromKernelRef(pub ArcRef<dyn Kernel>);
293inventory::collect!(TakeFromKernelRef);
294
295#[derive(Debug)]
296pub struct TakeFromKernelAdapter<V: VTable>(pub V);
297
298impl<V: VTable + TakeFromKernel> TakeFromKernelAdapter<V> {
299 pub const fn lift(&'static self) -> TakeFromKernelRef {
300 TakeFromKernelRef(ArcRef::new_ref(self))
301 }
302}
303
304impl<V: VTable + TakeFromKernel> Kernel for TakeFromKernelAdapter<V> {
305 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
306 let inputs = TakeArgs::try_from(args)?;
307 let Some(indices) = inputs.indices.as_opt::<V>() else {
308 return Ok(None);
309 };
310 Ok(V::take_from(&self.0, indices, inputs.array)?.map(Output::from))
311 }
312}