// vortex_array/compute/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding may provide its own efficient implementation of these operators;
//! otherwise, the array is decoded and the equivalent Arrow operator is applied.
11
12use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use between::*;
19pub use boolean::*;
20pub use cast::*;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24pub use invert::*;
25pub use is_constant::*;
26pub use is_sorted::*;
27use itertools::Itertools;
28pub use like::*;
29pub use list_contains::*;
30pub use mask::*;
31pub use min_max::*;
32pub use nan_count::*;
33pub use numeric::*;
34use parking_lot::RwLock;
35pub use sum::*;
36pub use take::*;
37use vortex_dtype::DType;
38use vortex_error::VortexError;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_err;
42use vortex_mask::Mask;
43use vortex_scalar::Scalar;
44pub use zip::*;
45
46use crate::Array;
47use crate::ArrayRef;
48use crate::builders::ArrayBuilder;
49
50#[cfg(feature = "arbitrary")]
51mod arbitrary;
52mod between;
53mod boolean;
54mod cast;
55mod compare;
56#[cfg(feature = "test-harness")]
57pub mod conformance;
58mod fill_null;
59mod filter;
60mod invert;
61mod is_constant;
62mod is_sorted;
63mod like;
64mod list_contains;
65mod mask;
66mod min_max;
67mod nan_count;
68mod numeric;
69mod sum;
70mod take;
71mod zip;
72
73/// An instance of a compute function holding the implementation vtable and a set of registered
74/// compute kernels.
75pub struct ComputeFn {
76    id: ArcRef<str>,
77    vtable: ArcRef<dyn ComputeFnVTable>,
78    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
79}
80
/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
///
/// This calls `crate::arrow::warm_up_vtable` followed by the `warm_up_vtable` function of every
/// compute-function submodule declared in this file, in a fixed order.
///
/// Mostly useful for small benchmarks where the overhead might cause noise depending on the order of benchmarks.
pub fn warm_up_vtables() {
    crate::arrow::warm_up_vtable();
    // NOTE(review): the reason this lint is suppressed for this one call is not visible here —
    // presumably related to `between` also being glob re-exported via `pub use between::*`;
    // TODO confirm and document.
    #[allow(unused_qualifications)]
    between::warm_up_vtable();
    boolean::warm_up_vtable();
    cast::warm_up_vtable();
    compare::warm_up_vtable();
    fill_null::warm_up_vtable();
    filter::warm_up_vtable();
    invert::warm_up_vtable();
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    like::warm_up_vtable();
    list_contains::warm_up_vtable();
    mask::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    numeric::warm_up_vtable();
    sum::warm_up_vtable();
    take::warm_up_vtable();
    zip::warm_up_vtable();
}
106
107impl ComputeFn {
108    /// Create a new compute function from the given [`ComputeFnVTable`].
109    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
110        Self {
111            id,
112            vtable,
113            kernels: Default::default(),
114        }
115    }
116
117    /// Returns the string identifier of the compute function.
118    pub fn id(&self) -> &ArcRef<str> {
119        &self.id
120    }
121
122    /// Register a kernel for the compute function.
123    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
124        self.kernels.write().push(kernel);
125    }
126
127    /// Invokes the compute function with the given arguments.
128    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
129        // Perform some pre-condition checks against the arguments and the function properties.
130        if self.is_elementwise() {
131            // For element-wise functions, all input arrays must be the same length.
132            if !args
133                .inputs
134                .iter()
135                .filter_map(|input| input.array())
136                .map(|array| array.len())
137                .all_equal()
138            {
139                vortex_bail!(
140                    "Compute function {} is elementwise but input arrays have different lengths",
141                    self.id
142                );
143            }
144        }
145
146        let expected_dtype = self.vtable.return_dtype(args)?;
147        let expected_len = self.vtable.return_len(args)?;
148
149        let output = self.vtable.invoke(args, &self.kernels.read())?;
150
151        if output.dtype() != &expected_dtype {
152            vortex_bail!(
153                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
154                self.id,
155                output.dtype(),
156                &expected_dtype,
157                args.inputs
158                    .iter()
159                    .filter_map(|input| input.array())
160                    .format_with(",", |array, f| f(&array.display_tree()))
161            );
162        }
163        if output.len() != expected_len {
164            vortex_bail!(
165                "Internal error: compute function {} returned a result of length {} but expected {}",
166                self.id,
167                output.len(),
168                expected_len
169            );
170        }
171
172        Ok(output)
173    }
174
175    /// Compute the return type of the function given the input arguments.
176    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
177        self.vtable.return_dtype(args)
178    }
179
180    /// Compute the return length of the function given the input arguments.
181    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
182        self.vtable.return_len(args)
183    }
184
185    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
186    pub fn is_elementwise(&self) -> bool {
187        // TODO(ngates): should this just be a constant passed in the constructor?
188        self.vtable.is_elementwise()
189    }
190
191    /// Returns the compute function's kernels.
192    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
193        self.kernels.read().to_vec()
194    }
195}
196
197/// VTable for the implementation of a compute function.
198pub trait ComputeFnVTable: 'static + Send + Sync {
199    /// Invokes the compute function entry-point with the given input arguments and options.
200    ///
201    /// The entry-point logic can short-circuit compute using statistics, update result array
202    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
203    /// to successfully compute a result.
204    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
205    -> VortexResult<Output>;
206
207    /// Computes the return type of the function given the input arguments.
208    ///
209    /// All kernel implementations will be validated to return the [`DType`] as computed here.
210    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
211
212    /// Computes the return length of the function given the input arguments.
213    ///
214    /// All kernel implementations will be validated to return the len as computed here.
215    /// Scalars are considered to have length 1.
216    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
217
218    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
219    /// input and no information is shared between elements.
220    ///
221    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
222    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
223    ///
224    /// All input arrays to an elementwise function *must* have the same length.
225    fn is_elementwise(&self) -> bool;
226}
227
228/// Arguments to a compute function invocation.
229#[derive(Clone)]
230pub struct InvocationArgs<'a> {
231    pub inputs: &'a [Input<'a>],
232    pub options: &'a dyn Options,
233}
234
235/// For unary compute functions, it's useful to just have this short-cut.
236pub struct UnaryArgs<'a, O: Options> {
237    pub array: &'a dyn Array,
238    pub options: &'a O,
239}
240
241impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
242    type Error = VortexError;
243
244    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
245        if value.inputs.len() != 1 {
246            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
247        }
248        let array = value.inputs[0]
249            .array()
250            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
251        let options =
252            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
253                vortex_err!("Expected options to be of type {}", type_name::<O>())
254            })?;
255        Ok(UnaryArgs { array, options })
256    }
257}
258
259/// For binary compute functions, it's useful to just have this short-cut.
260pub struct BinaryArgs<'a, O: Options> {
261    pub lhs: &'a dyn Array,
262    pub rhs: &'a dyn Array,
263    pub options: &'a O,
264}
265
266impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
267    type Error = VortexError;
268
269    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
270        if value.inputs.len() != 2 {
271            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
272        }
273        let lhs = value.inputs[0]
274            .array()
275            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
276        let rhs = value.inputs[1]
277            .array()
278            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
279        let options =
280            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
281                vortex_err!("Expected options to be of type {}", type_name::<O>())
282            })?;
283        Ok(BinaryArgs { lhs, rhs, options })
284    }
285}
286
287/// Input to a compute function.
288pub enum Input<'a> {
289    Scalar(&'a Scalar),
290    Array(&'a dyn Array),
291    Mask(&'a Mask),
292    Builder(&'a mut dyn ArrayBuilder),
293    DType(&'a DType),
294}
295
296impl Debug for Input<'_> {
297    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298        let mut f = f.debug_struct("Input");
299        match self {
300            Input::Scalar(scalar) => f.field("Scalar", scalar),
301            Input::Array(array) => f.field("Array", array),
302            Input::Mask(mask) => f.field("Mask", mask),
303            Input::Builder(builder) => f.field("Builder", &builder.len()),
304            Input::DType(dtype) => f.field("DType", dtype),
305        };
306        f.finish()
307    }
308}
309
310impl<'a> From<&'a dyn Array> for Input<'a> {
311    fn from(value: &'a dyn Array) -> Self {
312        Input::Array(value)
313    }
314}
315
316impl<'a> From<&'a Scalar> for Input<'a> {
317    fn from(value: &'a Scalar) -> Self {
318        Input::Scalar(value)
319    }
320}
321
322impl<'a> From<&'a Mask> for Input<'a> {
323    fn from(value: &'a Mask) -> Self {
324        Input::Mask(value)
325    }
326}
327
328impl<'a> From<&'a DType> for Input<'a> {
329    fn from(value: &'a DType) -> Self {
330        Input::DType(value)
331    }
332}
333
334impl<'a> Input<'a> {
335    pub fn scalar(&self) -> Option<&'a Scalar> {
336        if let Input::Scalar(scalar) = self {
337            Some(*scalar)
338        } else {
339            None
340        }
341    }
342
343    pub fn array(&self) -> Option<&'a dyn Array> {
344        if let Input::Array(array) = self {
345            Some(*array)
346        } else {
347            None
348        }
349    }
350
351    pub fn mask(&self) -> Option<&'a Mask> {
352        if let Input::Mask(mask) = self {
353            Some(*mask)
354        } else {
355            None
356        }
357    }
358
359    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
360        if let Input::Builder(builder) = self {
361            Some(*builder)
362        } else {
363            None
364        }
365    }
366
367    pub fn dtype(&self) -> Option<&'a DType> {
368        if let Input::DType(dtype) = self {
369            Some(*dtype)
370        } else {
371            None
372        }
373    }
374}
375
376/// Output from a compute function.
377#[derive(Debug)]
378pub enum Output {
379    Scalar(Scalar),
380    Array(ArrayRef),
381}
382
383#[expect(
384    clippy::len_without_is_empty,
385    reason = "Output is always non-empty (scalar has len 1)"
386)]
387impl Output {
388    pub fn dtype(&self) -> &DType {
389        match self {
390            Output::Scalar(scalar) => scalar.dtype(),
391            Output::Array(array) => array.dtype(),
392        }
393    }
394
395    pub fn len(&self) -> usize {
396        match self {
397            Output::Scalar(_) => 1,
398            Output::Array(array) => array.len(),
399        }
400    }
401
402    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
403        match self {
404            Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
405            Output::Scalar(scalar) => Ok(scalar),
406        }
407    }
408
409    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
410        match self {
411            Output::Array(array) => Ok(array),
412            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
413        }
414    }
415}
416
417impl From<ArrayRef> for Output {
418    fn from(value: ArrayRef) -> Self {
419        Output::Array(value)
420    }
421}
422
423impl From<Scalar> for Output {
424    fn from(value: Scalar) -> Self {
425        Output::Scalar(value)
426    }
427}
428
/// Options for a compute function invocation.
///
/// Options are passed type-erased; a compute function recovers its concrete options type by
/// downcasting the value returned from [`Options::as_any`].
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` so it can be downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}

/// Unit options, presumably for compute functions that take no options.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
439
440/// Compute functions can ask arrays for compute kernels for a given invocation.
441///
442/// The kernel is invoked with the input arguments and options, and can return `None` if it is
443/// unable to compute the result for the given inputs due to missing implementation logic.
444/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
445/// is indicating that it cannot compute the result for the given inputs, and another kernel should
446/// be tried. *Not* that the given inputs are invalid for the compute function.
447///
448/// If the kernel fails to compute a result, it should return a `Some` with the error.
449pub trait Kernel: 'static + Send + Sync + Debug {
450    /// Invokes the kernel with the given input arguments and options.
451    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
452}
453
/// Register a kernel for a compute function.
///
/// Expands to `inventory::submit!` (through `$crate::aliases::inventory`) applied to the given
/// expression. See each compute function for the correct type of kernel to register.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}