vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
9//! Every array encoding has the ability to implement their own efficient implementations of these
10//! operators, else we will decode, and perform the equivalent operator from Arrow.
11
12use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46pub mod arrays;
47mod between;
48mod boolean;
49mod cast;
50mod compare;
51#[cfg(feature = "test-harness")]
52pub mod conformance;
53mod fill_null;
54mod filter;
55mod invert;
56mod is_constant;
57mod is_sorted;
58mod like;
59mod list_contains;
60mod mask;
61mod min_max;
62mod nan_count;
63mod numeric;
64mod sum;
65mod take;
66mod zip;
67
/// An instance of a compute function holding the implementation vtable and a set of registered
/// compute kernels.
pub struct ComputeFn {
    /// String identifier of the function; returned by `id()` and included in error messages.
    id: ArcRef<str>,
    /// Implementation that `invoke`, `return_dtype`, `return_len` and `is_elementwise`
    /// delegate to.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added through `register_kernel`; passed to the vtable on every `invoke`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
75
/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
///
/// Mostly useful for small benchmarks where the overhead might cause noise depending on the order of benchmarks.
///
/// This simply calls `warm_up_vtable()` in `crate::arrow` and then in each compute submodule
/// listed below, in the order written.
pub fn warm_up_vtables() {
    crate::arrow::warm_up_vtable();
    // NOTE(review): a statement-level attribute covers only the statement it is attached to,
    // i.e. the `between` call — confirm whether the other qualified calls need it as well.
    #[allow(unused_qualifications)]
    between::warm_up_vtable();
    boolean::warm_up_vtable();
    cast::warm_up_vtable();
    compare::warm_up_vtable();
    fill_null::warm_up_vtable();
    filter::warm_up_vtable();
    invert::warm_up_vtable();
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    like::warm_up_vtable();
    list_contains::warm_up_vtable();
    mask::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    numeric::warm_up_vtable();
    sum::warm_up_vtable();
    take::warm_up_vtable();
    zip::warm_up_vtable();
}
101
102impl ComputeFn {
103    /// Create a new compute function from the given [`ComputeFnVTable`].
104    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
105        Self {
106            id,
107            vtable,
108            kernels: Default::default(),
109        }
110    }
111
112    /// Returns the string identifier of the compute function.
113    pub fn id(&self) -> &ArcRef<str> {
114        &self.id
115    }
116
117    /// Register a kernel for the compute function.
118    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
119        self.kernels.write().push(kernel);
120    }
121
122    /// Invokes the compute function with the given arguments.
123    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
124        // Perform some pre-condition checks against the arguments and the function properties.
125        if self.is_elementwise() {
126            // For element-wise functions, all input arrays must be the same length.
127            if !args
128                .inputs
129                .iter()
130                .filter_map(|input| input.array())
131                .map(|array| array.len())
132                .all_equal()
133            {
134                vortex_bail!(
135                    "Compute function {} is elementwise but input arrays have different lengths",
136                    self.id
137                );
138            }
139        }
140
141        let expected_dtype = self.vtable.return_dtype(args)?;
142        let expected_len = self.vtable.return_len(args)?;
143
144        let output = self.vtable.invoke(args, &self.kernels.read())?;
145
146        if output.dtype() != &expected_dtype {
147            vortex_bail!(
148                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
149                self.id,
150                output.dtype(),
151                &expected_dtype,
152                args.inputs
153                    .iter()
154                    .filter_map(|input| input.array())
155                    .format_with(",", |array, f| f(&array.display_tree()))
156            );
157        }
158        if output.len() != expected_len {
159            vortex_bail!(
160                "Internal error: compute function {} returned a result of length {} but expected {}",
161                self.id,
162                output.len(),
163                expected_len
164            );
165        }
166
167        Ok(output)
168    }
169
170    /// Compute the return type of the function given the input arguments.
171    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
172        self.vtable.return_dtype(args)
173    }
174
175    /// Compute the return length of the function given the input arguments.
176    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
177        self.vtable.return_len(args)
178    }
179
180    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
181    pub fn is_elementwise(&self) -> bool {
182        // TODO(ngates): should this just be a constant passed in the constructor?
183        self.vtable.is_elementwise()
184    }
185
186    /// Returns the compute function's kernels.
187    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
188        self.kernels.read().to_vec()
189    }
190}
191
/// VTable for the implementation of a compute function.
///
/// A [`ComputeFn`] wraps one of these and delegates to it; see [`ComputeFn::invoke`] for the
/// checks performed around [`ComputeFnVTable::invoke`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the compute function entry-point with the given input arguments and options.
    ///
    /// The entry-point logic can short-circuit compute using statistics, update result array
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
    /// to successfully compute a result.
    ///
    /// `kernels` is the list of kernels registered on the owning [`ComputeFn`] at the time of
    /// the call.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the return type of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
    /// [`ComputeFn::invoke`] returns an error if the output dtype differs.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the return length of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the len as computed here.
    /// Scalars are considered to have length 1.
    /// [`ComputeFn::invoke`] returns an error if the output length differs.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
    /// input and no information is shared between elements.
    ///
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
    ///
    /// All input arrays to an elementwise function *must* have the same length.
    fn is_elementwise(&self) -> bool;
}
222
/// Arguments to a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs to the function.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options; downcast through [`Options::as_any`] to the concrete type.
    pub options: &'a dyn Options,
}
229
230/// For unary compute functions, it's useful to just have this short-cut.
231pub struct UnaryArgs<'a, O: Options> {
232    pub array: &'a dyn Array,
233    pub options: &'a O,
234}
235
236impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
237    type Error = VortexError;
238
239    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
240        if value.inputs.len() != 1 {
241            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
242        }
243        let array = value.inputs[0]
244            .array()
245            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
246        let options =
247            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
248                vortex_err!("Expected options to be of type {}", type_name::<O>())
249            })?;
250        Ok(UnaryArgs { array, options })
251    }
252}
253
254/// For binary compute functions, it's useful to just have this short-cut.
255pub struct BinaryArgs<'a, O: Options> {
256    pub lhs: &'a dyn Array,
257    pub rhs: &'a dyn Array,
258    pub options: &'a O,
259}
260
261impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
262    type Error = VortexError;
263
264    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
265        if value.inputs.len() != 2 {
266            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
267        }
268        let lhs = value.inputs[0]
269            .array()
270            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
271        let rhs = value.inputs[1]
272            .array()
273            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
274        let options =
275            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
276                vortex_err!("Expected options to be of type {}", type_name::<O>())
277            })?;
278        Ok(BinaryArgs { lhs, rhs, options })
279    }
280}
281
/// Input to a compute function.
///
/// Every variant borrows its payload for the lifetime `'a`.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
290
291impl Debug for Input<'_> {
292    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293        let mut f = f.debug_struct("Input");
294        match self {
295            Input::Scalar(scalar) => f.field("Scalar", scalar),
296            Input::Array(array) => f.field("Array", array),
297            Input::Mask(mask) => f.field("Mask", mask),
298            Input::Builder(builder) => f.field("Builder", &builder.len()),
299            Input::DType(dtype) => f.field("DType", dtype),
300        };
301        f.finish()
302    }
303}
304
305impl<'a> From<&'a dyn Array> for Input<'a> {
306    fn from(value: &'a dyn Array) -> Self {
307        Input::Array(value)
308    }
309}
310
311impl<'a> From<&'a Scalar> for Input<'a> {
312    fn from(value: &'a Scalar) -> Self {
313        Input::Scalar(value)
314    }
315}
316
317impl<'a> From<&'a Mask> for Input<'a> {
318    fn from(value: &'a Mask) -> Self {
319        Input::Mask(value)
320    }
321}
322
323impl<'a> From<&'a DType> for Input<'a> {
324    fn from(value: &'a DType) -> Self {
325        Input::DType(value)
326    }
327}
328
329impl<'a> Input<'a> {
330    pub fn scalar(&self) -> Option<&'a Scalar> {
331        if let Input::Scalar(scalar) = self {
332            Some(*scalar)
333        } else {
334            None
335        }
336    }
337
338    pub fn array(&self) -> Option<&'a dyn Array> {
339        if let Input::Array(array) = self {
340            Some(*array)
341        } else {
342            None
343        }
344    }
345
346    pub fn mask(&self) -> Option<&'a Mask> {
347        if let Input::Mask(mask) = self {
348            Some(*mask)
349        } else {
350            None
351        }
352    }
353
354    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
355        if let Input::Builder(builder) = self {
356            Some(*builder)
357        } else {
358            None
359        }
360    }
361
362    pub fn dtype(&self) -> Option<&'a DType> {
363        if let Input::DType(dtype) = self {
364            Some(*dtype)
365        } else {
366            None
367        }
368    }
369}
370
/// Output from a compute function.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result; reported as having length 1 by `Output::len`.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
377
378#[allow(clippy::len_without_is_empty)]
379impl Output {
380    pub fn dtype(&self) -> &DType {
381        match self {
382            Output::Scalar(scalar) => scalar.dtype(),
383            Output::Array(array) => array.dtype(),
384        }
385    }
386
387    pub fn len(&self) -> usize {
388        match self {
389            Output::Scalar(_) => 1,
390            Output::Array(array) => array.len(),
391        }
392    }
393
394    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
395        match self {
396            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
397            Output::Scalar(scalar) => Ok(scalar),
398        }
399    }
400
401    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
402        match self {
403            Output::Array(array) => Ok(array),
404            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
405        }
406    }
407}
408
409impl From<ArrayRef> for Output {
410    fn from(value: ArrayRef) -> Self {
411        Output::Array(value)
412    }
413}
414
415impl From<Scalar> for Output {
416    fn from(value: Scalar) -> Self {
417        Output::Scalar(value)
418    }
419}
420
/// Options for a compute function invocation.
///
/// Options are passed around as `&dyn Options` and recovered as a concrete type by downcasting
/// the value returned from [`Options::as_any`] (see `UnaryArgs` and `BinaryArgs`).
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}

// The unit type is a valid options value; presumably for compute functions that need no
// options — confirm against callers.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
431
/// Compute functions can ask arrays for compute kernels for a given invocation.
///
/// The kernel is invoked with the input arguments and options, and can return `Ok(None)` if it
/// is unable to compute the result for the given inputs due to missing implementation logic.
/// For example, if the kernel doesn't support the `LTE` operator. By returning `Ok(None)`, the
/// kernel is indicating that it cannot compute the result for the given inputs, and another
/// kernel should be tried. It does *not* mean that the given inputs are invalid for the compute
/// function.
///
/// If the kernel fails to compute a result, it should return an `Err` with the error.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Invokes the kernel with the given input arguments and options.
    ///
    /// Returns `Ok(Some(output))` on success and `Ok(None)` when the kernel does not handle
    /// these inputs.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
445
/// Register a kernel for a compute function.
/// See each compute function for the correct type of kernel to register.
///
/// Expands to `$crate::aliases::inventory::submit!($T)`, so `$T` must be an expression that
/// `inventory::submit!` accepts.
#[macro_export]
macro_rules! register_kernel {
    // `$T` is the kernel value expression forwarded verbatim to `inventory::submit!`.
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}