Skip to main content

vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementations of these operators;
//! otherwise, we decode the array and perform the equivalent operation using Arrow.
11
12use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use boolean::*;
19#[expect(deprecated)]
20pub use cast::cast;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24#[expect(deprecated)]
25pub use invert::invert;
26pub use is_constant::*;
27pub use is_sorted::*;
28use itertools::Itertools;
29pub use list_contains::*;
30pub use mask::*;
31pub use min_max::*;
32pub use nan_count::*;
33pub use numeric::*;
34use parking_lot::RwLock;
35pub use sum::*;
36use vortex_dtype::DType;
37use vortex_error::VortexError;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_err;
41use vortex_mask::Mask;
42pub use zip::*;
43
44use crate::Array;
45use crate::ArrayRef;
46use crate::builders::ArrayBuilder;
47pub use crate::expr::BetweenExecuteAdaptor;
48pub use crate::expr::BetweenKernel;
49pub use crate::expr::BetweenReduce;
50pub use crate::expr::BetweenReduceAdaptor;
51pub use crate::expr::CastExecuteAdaptor;
52pub use crate::expr::CastKernel;
53pub use crate::expr::CastReduce;
54pub use crate::expr::CastReduceAdaptor;
55pub use crate::expr::FillNullExecuteAdaptor;
56pub use crate::expr::FillNullKernel;
57pub use crate::expr::FillNullReduce;
58pub use crate::expr::FillNullReduceAdaptor;
59pub use crate::expr::MaskExecuteAdaptor;
60pub use crate::expr::MaskKernel;
61pub use crate::expr::MaskReduce;
62pub use crate::expr::MaskReduceAdaptor;
63pub use crate::expr::NotExecuteAdaptor;
64pub use crate::expr::NotKernel;
65pub use crate::expr::NotReduce;
66pub use crate::expr::NotReduceAdaptor;
67use crate::scalar::Scalar;
68
69#[cfg(feature = "arbitrary")]
70mod arbitrary;
71mod boolean;
72mod cast;
73mod compare;
74#[cfg(feature = "_test-harness")]
75pub mod conformance;
76mod fill_null;
77mod filter;
78mod invert;
79mod is_constant;
80mod is_sorted;
81mod list_contains;
82mod mask;
83mod min_max;
84mod nan_count;
85mod numeric;
86mod sum;
87mod zip;
88
89/// An instance of a compute function holding the implementation vtable and a set of registered
90/// compute kernels.
91pub struct ComputeFn {
92    id: ArcRef<str>,
93    vtable: ArcRef<dyn ComputeFnVTable>,
94    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
95}
96
97/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
98///
99/// Mostly useful for small benchmarks where the overhead might cause noise depending on the order of benchmarks.
100pub fn warm_up_vtables() {
101    #[allow(unused_qualifications)]
102    is_constant::warm_up_vtable();
103    is_sorted::warm_up_vtable();
104    list_contains::warm_up_vtable();
105    min_max::warm_up_vtable();
106    nan_count::warm_up_vtable();
107    sum::warm_up_vtable();
108}
109
110impl ComputeFn {
111    /// Create a new compute function from the given [`ComputeFnVTable`].
112    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
113        Self {
114            id,
115            vtable,
116            kernels: Default::default(),
117        }
118    }
119
120    /// Returns the string identifier of the compute function.
121    pub fn id(&self) -> &ArcRef<str> {
122        &self.id
123    }
124
125    /// Register a kernel for the compute function.
126    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
127        self.kernels.write().push(kernel);
128    }
129
130    /// Invokes the compute function with the given arguments.
131    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
132        // Perform some pre-condition checks against the arguments and the function properties.
133        if self.is_elementwise() {
134            // For element-wise functions, all input arrays must be the same length.
135            if !args
136                .inputs
137                .iter()
138                .filter_map(|input| input.array())
139                .map(|array| array.len())
140                .all_equal()
141            {
142                vortex_bail!(
143                    "Compute function {} is elementwise but input arrays have different lengths",
144                    self.id
145                );
146            }
147        }
148
149        let expected_dtype = self.vtable.return_dtype(args)?;
150        let expected_len = self.vtable.return_len(args)?;
151
152        let output = self.vtable.invoke(args, &self.kernels.read())?;
153
154        if output.dtype() != &expected_dtype {
155            vortex_bail!(
156                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
157                self.id,
158                output.dtype(),
159                &expected_dtype,
160                args.inputs
161                    .iter()
162                    .filter_map(|input| input.array())
163                    .format_with(",", |array, f| f(&array.encoding_id()))
164            );
165        }
166        if output.len() != expected_len {
167            vortex_bail!(
168                "Internal error: compute function {} returned a result of length {} but expected {}",
169                self.id,
170                output.len(),
171                expected_len
172            );
173        }
174
175        Ok(output)
176    }
177
178    /// Compute the return type of the function given the input arguments.
179    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
180        self.vtable.return_dtype(args)
181    }
182
183    /// Compute the return length of the function given the input arguments.
184    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
185        self.vtable.return_len(args)
186    }
187
188    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
189    pub fn is_elementwise(&self) -> bool {
190        // TODO(ngates): should this just be a constant passed in the constructor?
191        self.vtable.is_elementwise()
192    }
193
194    /// Returns the compute function's kernels.
195    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
196        self.kernels.read().to_vec()
197    }
198}
199
200/// VTable for the implementation of a compute function.
201pub trait ComputeFnVTable: 'static + Send + Sync {
202    /// Invokes the compute function entry-point with the given input arguments and options.
203    ///
204    /// The entry-point logic can short-circuit compute using statistics, update result array
205    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
206    /// to successfully compute a result.
207    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
208    -> VortexResult<Output>;
209
210    /// Computes the return type of the function given the input arguments.
211    ///
212    /// All kernel implementations will be validated to return the [`DType`] as computed here.
213    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
214
215    /// Computes the return length of the function given the input arguments.
216    ///
217    /// All kernel implementations will be validated to return the len as computed here.
218    /// Scalars are considered to have length 1.
219    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
220
221    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
222    /// input and no information is shared between elements.
223    ///
224    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
225    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
226    ///
227    /// All input arrays to an elementwise function *must* have the same length.
228    fn is_elementwise(&self) -> bool;
229}
230
231/// Arguments to a compute function invocation.
232#[derive(Clone)]
233pub struct InvocationArgs<'a> {
234    pub inputs: &'a [Input<'a>],
235    pub options: &'a dyn Options,
236}
237
238/// For unary compute functions, it's useful to just have this short-cut.
239pub struct UnaryArgs<'a, O: Options> {
240    pub array: &'a dyn Array,
241    pub options: &'a O,
242}
243
244impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
245    type Error = VortexError;
246
247    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
248        if value.inputs.len() != 1 {
249            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
250        }
251        let array = value.inputs[0]
252            .array()
253            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
254        let options =
255            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
256                vortex_err!("Expected options to be of type {}", type_name::<O>())
257            })?;
258        Ok(UnaryArgs { array, options })
259    }
260}
261
262/// For binary compute functions, it's useful to just have this short-cut.
263pub struct BinaryArgs<'a, O: Options> {
264    pub lhs: &'a dyn Array,
265    pub rhs: &'a dyn Array,
266    pub options: &'a O,
267}
268
269impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
270    type Error = VortexError;
271
272    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
273        if value.inputs.len() != 2 {
274            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
275        }
276        let lhs = value.inputs[0]
277            .array()
278            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
279        let rhs = value.inputs[1]
280            .array()
281            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
282        let options =
283            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
284                vortex_err!("Expected options to be of type {}", type_name::<O>())
285            })?;
286        Ok(BinaryArgs { lhs, rhs, options })
287    }
288}
289
290/// Input to a compute function.
291pub enum Input<'a> {
292    Scalar(&'a Scalar),
293    Array(&'a dyn Array),
294    Mask(&'a Mask),
295    Builder(&'a mut dyn ArrayBuilder),
296    DType(&'a DType),
297}
298
299impl Debug for Input<'_> {
300    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
301        let mut f = f.debug_struct("Input");
302        match self {
303            Input::Scalar(scalar) => f.field("Scalar", scalar),
304            Input::Array(array) => f.field("Array", array),
305            Input::Mask(mask) => f.field("Mask", mask),
306            Input::Builder(builder) => f.field("Builder", &builder.len()),
307            Input::DType(dtype) => f.field("DType", dtype),
308        };
309        f.finish()
310    }
311}
312
313impl<'a> From<&'a dyn Array> for Input<'a> {
314    fn from(value: &'a dyn Array) -> Self {
315        Input::Array(value)
316    }
317}
318
319impl<'a> From<&'a Scalar> for Input<'a> {
320    fn from(value: &'a Scalar) -> Self {
321        Input::Scalar(value)
322    }
323}
324
325impl<'a> From<&'a Mask> for Input<'a> {
326    fn from(value: &'a Mask) -> Self {
327        Input::Mask(value)
328    }
329}
330
331impl<'a> From<&'a DType> for Input<'a> {
332    fn from(value: &'a DType) -> Self {
333        Input::DType(value)
334    }
335}
336
337impl<'a> Input<'a> {
338    pub fn scalar(&self) -> Option<&'a Scalar> {
339        if let Input::Scalar(scalar) = self {
340            Some(*scalar)
341        } else {
342            None
343        }
344    }
345
346    pub fn array(&self) -> Option<&'a dyn Array> {
347        if let Input::Array(array) = self {
348            Some(*array)
349        } else {
350            None
351        }
352    }
353
354    pub fn mask(&self) -> Option<&'a Mask> {
355        if let Input::Mask(mask) = self {
356            Some(*mask)
357        } else {
358            None
359        }
360    }
361
362    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
363        if let Input::Builder(builder) = self {
364            Some(*builder)
365        } else {
366            None
367        }
368    }
369
370    pub fn dtype(&self) -> Option<&'a DType> {
371        if let Input::DType(dtype) = self {
372            Some(*dtype)
373        } else {
374            None
375        }
376    }
377}
378
379/// Output from a compute function.
380#[derive(Debug)]
381pub enum Output {
382    Scalar(Scalar),
383    Array(ArrayRef),
384}
385
386#[expect(
387    clippy::len_without_is_empty,
388    reason = "Output is always non-empty (scalar has len 1)"
389)]
390impl Output {
391    pub fn dtype(&self) -> &DType {
392        match self {
393            Output::Scalar(scalar) => scalar.dtype(),
394            Output::Array(array) => array.dtype(),
395        }
396    }
397
398    pub fn len(&self) -> usize {
399        match self {
400            Output::Scalar(_) => 1,
401            Output::Array(array) => array.len(),
402        }
403    }
404
405    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
406        match self {
407            Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
408            Output::Scalar(scalar) => Ok(scalar),
409        }
410    }
411
412    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
413        match self {
414            Output::Array(array) => Ok(array),
415            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
416        }
417    }
418}
419
420impl From<ArrayRef> for Output {
421    fn from(value: ArrayRef) -> Self {
422        Output::Array(value)
423    }
424}
425
426impl From<Scalar> for Output {
427    fn from(value: Scalar) -> Self {
428        Output::Scalar(value)
429    }
430}
431
/// Options for a compute function invocation.
///
/// Options travel type-erased as `&dyn Options`; [`Options::as_any`] lets a compute function
/// downcast them back to its concrete options type.
pub trait Options: 'static {
    /// Returns `self` as [`Any`] so that it can be downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}
436
437impl Options for () {
438    fn as_any(&self) -> &dyn Any {
439        self
440    }
441}
442
443/// Compute functions can ask arrays for compute kernels for a given invocation.
444///
445/// The kernel is invoked with the input arguments and options, and can return `None` if it is
446/// unable to compute the result for the given inputs due to missing implementation logic.
447/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
448/// is indicating that it cannot compute the result for the given inputs, and another kernel should
449/// be tried. *Not* that the given inputs are invalid for the compute function.
450///
451/// If the kernel fails to compute a result, it should return a `Some` with the error.
452pub trait Kernel: 'static + Send + Sync + Debug {
453    /// Invokes the kernel with the given input arguments and options.
454    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
455}
456
/// Register a kernel for a compute function.
///
/// Expands to `inventory::submit!` of the given expression (via the crate's `aliases`
/// re-export). See each compute function for the correct type of kernel to register.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}