vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementation of these operators;
//! otherwise we decode the array and perform the equivalent operation using Arrow.
11
12use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46mod between;
47mod boolean;
48mod cast;
49mod compare;
50#[cfg(feature = "test-harness")]
51pub mod conformance;
52mod fill_null;
53mod filter;
54mod invert;
55mod is_constant;
56mod is_sorted;
57mod like;
58mod list_contains;
59mod mask;
60mod min_max;
61mod nan_count;
62mod numeric;
63mod sum;
64mod take;
65mod zip;
66
/// An instance of a compute function holding the implementation vtable and a set of registered
/// compute kernels.
pub struct ComputeFn {
    /// String identifier of the compute function, exposed via [`ComputeFn::id`].
    id: ArcRef<str>,
    /// Implementation that the `invoke`, `return_dtype`, `return_len` and `is_elementwise`
    /// methods delegate to.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added through [`ComputeFn::register_kernel`]. Kept behind a lock so that
    /// registration only needs `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
74
/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
///
/// Mostly useful for small benchmarks, where the registration overhead might otherwise cause
/// noise depending on the order of benchmarks.
///
/// This calls `warm_up_vtable` in `crate::arrow` and in each of the compute function modules
/// listed in the body.
pub fn warm_up_vtables() {
    crate::arrow::warm_up_vtable();
    // NOTE(review): as written, this attribute is attached to the `between` call only —
    // confirm whether it was meant to cover the calls that follow as well.
    #[allow(unused_qualifications)]
    between::warm_up_vtable();
    boolean::warm_up_vtable();
    cast::warm_up_vtable();
    compare::warm_up_vtable();
    fill_null::warm_up_vtable();
    filter::warm_up_vtable();
    invert::warm_up_vtable();
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    like::warm_up_vtable();
    list_contains::warm_up_vtable();
    mask::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    numeric::warm_up_vtable();
    sum::warm_up_vtable();
    take::warm_up_vtable();
    zip::warm_up_vtable();
}
100
101impl ComputeFn {
102    /// Create a new compute function from the given [`ComputeFnVTable`].
103    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
104        Self {
105            id,
106            vtable,
107            kernels: Default::default(),
108        }
109    }
110
111    /// Returns the string identifier of the compute function.
112    pub fn id(&self) -> &ArcRef<str> {
113        &self.id
114    }
115
116    /// Register a kernel for the compute function.
117    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
118        self.kernels.write().push(kernel);
119    }
120
121    /// Invokes the compute function with the given arguments.
122    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
123        // Perform some pre-condition checks against the arguments and the function properties.
124        if self.is_elementwise() {
125            // For element-wise functions, all input arrays must be the same length.
126            if !args
127                .inputs
128                .iter()
129                .filter_map(|input| input.array())
130                .map(|array| array.len())
131                .all_equal()
132            {
133                vortex_bail!(
134                    "Compute function {} is elementwise but input arrays have different lengths",
135                    self.id
136                );
137            }
138        }
139
140        let expected_dtype = self.vtable.return_dtype(args)?;
141        let expected_len = self.vtable.return_len(args)?;
142
143        let output = self.vtable.invoke(args, &self.kernels.read())?;
144
145        if output.dtype() != &expected_dtype {
146            vortex_bail!(
147                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
148                self.id,
149                output.dtype(),
150                &expected_dtype,
151                args.inputs
152                    .iter()
153                    .filter_map(|input| input.array())
154                    .format_with(",", |array, f| f(&array.display_tree()))
155            );
156        }
157        if output.len() != expected_len {
158            vortex_bail!(
159                "Internal error: compute function {} returned a result of length {} but expected {}",
160                self.id,
161                output.len(),
162                expected_len
163            );
164        }
165
166        Ok(output)
167    }
168
169    /// Compute the return type of the function given the input arguments.
170    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
171        self.vtable.return_dtype(args)
172    }
173
174    /// Compute the return length of the function given the input arguments.
175    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
176        self.vtable.return_len(args)
177    }
178
179    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
180    pub fn is_elementwise(&self) -> bool {
181        // TODO(ngates): should this just be a constant passed in the constructor?
182        self.vtable.is_elementwise()
183    }
184
185    /// Returns the compute function's kernels.
186    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
187        self.kernels.read().to_vec()
188    }
189}
190
/// VTable for the implementation of a compute function.
///
/// A [`ComputeFn`] wraps one of these and delegates to it; see [`ComputeFn::invoke`] for the
/// checks performed around [`ComputeFnVTable::invoke`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the compute function entry-point with the given input arguments and options.
    ///
    /// The entry-point logic can short-circuit compute using statistics, update result array
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
    /// to successfully compute a result.
    ///
    /// `kernels` is the list of kernels registered on the owning [`ComputeFn`] at the time of
    /// the call.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the return type of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the return length of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the len as computed here.
    /// Scalars are considered to have length 1.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
    /// input and no information is shared between elements.
    ///
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
    ///
    /// All input arrays to an elementwise function *must* have the same length.
    fn is_elementwise(&self) -> bool;
}
221
/// Arguments to a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs to the function; see [`Input`] for the kinds of value that can be
    /// passed.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options, type-erased behind [`Options`]. Implementations recover the
    /// concrete type by downcasting through [`Options::as_any`].
    pub options: &'a dyn Options,
}
228
/// For unary compute functions, it's useful to just have this short-cut.
///
/// Obtained from an [`InvocationArgs`] via `TryFrom`, which checks that there is exactly one
/// array input and that the options have the concrete type `O`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The invocation options, downcast to their concrete type.
    pub options: &'a O,
}
234
235impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
236    type Error = VortexError;
237
238    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
239        if value.inputs.len() != 1 {
240            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
241        }
242        let array = value.inputs[0]
243            .array()
244            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
245        let options =
246            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
247                vortex_err!("Expected options to be of type {}", type_name::<O>())
248            })?;
249        Ok(UnaryArgs { array, options })
250    }
251}
252
/// For binary compute functions, it's useful to just have this short-cut.
///
/// Obtained from an [`InvocationArgs`] via `TryFrom`, which checks that there are exactly two
/// array inputs and that the options have the concrete type `O`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options, downcast to their concrete type.
    pub options: &'a O,
}
259
260impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
261    type Error = VortexError;
262
263    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
264        if value.inputs.len() != 2 {
265            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
266        }
267        let lhs = value.inputs[0]
268            .array()
269            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
270        let rhs = value.inputs[1]
271            .array()
272            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
273        let options =
274            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
275                vortex_err!("Expected options to be of type {}", type_name::<O>())
276            })?;
277        Ok(BinaryArgs { lhs, rhs, options })
278    }
279}
280
/// Input to a compute function.
///
/// Each variant borrows its value for the duration of the invocation.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
289
290impl Debug for Input<'_> {
291    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292        let mut f = f.debug_struct("Input");
293        match self {
294            Input::Scalar(scalar) => f.field("Scalar", scalar),
295            Input::Array(array) => f.field("Array", array),
296            Input::Mask(mask) => f.field("Mask", mask),
297            Input::Builder(builder) => f.field("Builder", &builder.len()),
298            Input::DType(dtype) => f.field("DType", dtype),
299        };
300        f.finish()
301    }
302}
303
304impl<'a> From<&'a dyn Array> for Input<'a> {
305    fn from(value: &'a dyn Array) -> Self {
306        Input::Array(value)
307    }
308}
309
310impl<'a> From<&'a Scalar> for Input<'a> {
311    fn from(value: &'a Scalar) -> Self {
312        Input::Scalar(value)
313    }
314}
315
316impl<'a> From<&'a Mask> for Input<'a> {
317    fn from(value: &'a Mask) -> Self {
318        Input::Mask(value)
319    }
320}
321
322impl<'a> From<&'a DType> for Input<'a> {
323    fn from(value: &'a DType) -> Self {
324        Input::DType(value)
325    }
326}
327
328impl<'a> Input<'a> {
329    pub fn scalar(&self) -> Option<&'a Scalar> {
330        if let Input::Scalar(scalar) = self {
331            Some(*scalar)
332        } else {
333            None
334        }
335    }
336
337    pub fn array(&self) -> Option<&'a dyn Array> {
338        if let Input::Array(array) = self {
339            Some(*array)
340        } else {
341            None
342        }
343    }
344
345    pub fn mask(&self) -> Option<&'a Mask> {
346        if let Input::Mask(mask) = self {
347            Some(*mask)
348        } else {
349            None
350        }
351    }
352
353    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
354        if let Input::Builder(builder) = self {
355            Some(*builder)
356        } else {
357            None
358        }
359    }
360
361    pub fn dtype(&self) -> Option<&'a DType> {
362        if let Input::DType(dtype) = self {
363            Some(*dtype)
364        } else {
365            None
366        }
367    }
368}
369
/// Output from a compute function.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result; reported as length 1 by [`Output::len`].
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
376
377#[allow(clippy::len_without_is_empty)]
378impl Output {
379    pub fn dtype(&self) -> &DType {
380        match self {
381            Output::Scalar(scalar) => scalar.dtype(),
382            Output::Array(array) => array.dtype(),
383        }
384    }
385
386    pub fn len(&self) -> usize {
387        match self {
388            Output::Scalar(_) => 1,
389            Output::Array(array) => array.len(),
390        }
391    }
392
393    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
394        match self {
395            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
396            Output::Scalar(scalar) => Ok(scalar),
397        }
398    }
399
400    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
401        match self {
402            Output::Array(array) => Ok(array),
403            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
404        }
405    }
406}
407
408impl From<ArrayRef> for Output {
409    fn from(value: ArrayRef) -> Self {
410        Output::Array(value)
411    }
412}
413
414impl From<Scalar> for Output {
415    fn from(value: Scalar) -> Self {
416        Output::Scalar(value)
417    }
418}
419
/// Options for a compute function invocation.
///
/// Options are passed around type-erased as `&dyn Options`; implementations recover the
/// concrete type by downcasting the value returned from [`Options::as_any`].
pub trait Options: 'static {
    /// Returns `self` as a [`dyn Any`](Any) so that it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
424
/// The unit type can be used as the options of a compute function.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
430
/// Compute functions can ask arrays for compute kernels for a given invocation.
///
/// The kernel is invoked with the input arguments and options, and can return `Ok(None)` if it is
/// unable to compute the result for the given inputs due to missing implementation logic, for
/// example if the kernel doesn't support the `LTE` operator. By returning `None`, the kernel
/// is indicating that it cannot compute the result for the given inputs, and another kernel should
/// be tried. It does *not* indicate that the given inputs are invalid for the compute function.
///
/// If the kernel fails to compute a result, it should return an `Err`.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Invokes the kernel with the given input arguments and options.
    ///
    /// Returns `Ok(Some(output))` on success and `Ok(None)` when this kernel does not handle the
    /// given inputs.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
444
/// Register a kernel for a compute function.
/// See each compute function for the correct type of kernel to register.
///
/// Expands to a call to `inventory::submit!` (reached through `$crate::aliases::inventory`) with
/// the given expression.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}