vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementations of these operators;
//! otherwise we decode the array and perform the equivalent operation from Arrow.
11
12use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39
40use crate::builders::ArrayBuilder;
41use crate::{Array, ArrayRef};
42
43#[cfg(feature = "arbitrary")]
44mod arbitrary;
45mod between;
46mod boolean;
47mod cast;
48mod compare;
49#[cfg(feature = "test-harness")]
50pub mod conformance;
51mod fill_null;
52mod filter;
53mod invert;
54mod is_constant;
55mod is_sorted;
56mod like;
57mod list_contains;
58mod mask;
59mod min_max;
60mod nan_count;
61mod numeric;
62mod sum;
63mod take;
64
/// An instance of a compute function holding the implementation vtable and a set of registered
/// compute kernels.
pub struct ComputeFn {
    /// String identifier of the compute function, exposed through [`ComputeFn::id`].
    id: ArcRef<str>,
    /// Implementation that [`ComputeFn::invoke`] delegates to.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added via [`ComputeFn::register_kernel`]; held behind a lock because
    /// registration only has `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
72
73impl ComputeFn {
74    /// Create a new compute function from the given [`ComputeFnVTable`].
75    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
76        Self {
77            id,
78            vtable,
79            kernels: Default::default(),
80        }
81    }
82
83    /// Returns the string identifier of the compute function.
84    pub fn id(&self) -> &ArcRef<str> {
85        &self.id
86    }
87
88    /// Register a kernel for the compute function.
89    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
90        self.kernels.write().push(kernel);
91    }
92
93    /// Invokes the compute function with the given arguments.
94    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
95        // Perform some pre-condition checks against the arguments and the function properties.
96        if self.is_elementwise() {
97            // For element-wise functions, all input arrays must be the same length.
98            if !args
99                .inputs
100                .iter()
101                .filter_map(|input| input.array())
102                .map(|array| array.len())
103                .all_equal()
104            {
105                vortex_bail!(
106                    "Compute function {} is elementwise but input arrays have different lengths",
107                    self.id
108                );
109            }
110        }
111
112        let expected_dtype = self.vtable.return_dtype(args)?;
113        let expected_len = self.vtable.return_len(args)?;
114
115        let output = self.vtable.invoke(args, &self.kernels.read())?;
116
117        if output.dtype() != &expected_dtype {
118            vortex_bail!(
119                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
120                self.id,
121                output.dtype(),
122                &expected_dtype,
123                args.inputs
124                    .iter()
125                    .filter_map(|input| input.array())
126                    .format_with(",", |array, f| f(&array.display_tree()))
127            );
128        }
129        if output.len() != expected_len {
130            vortex_bail!(
131                "Internal error: compute function {} returned a result of length {} but expected {}",
132                self.id,
133                output.len(),
134                expected_len
135            );
136        }
137
138        Ok(output)
139    }
140
141    /// Compute the return type of the function given the input arguments.
142    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
143        self.vtable.return_dtype(args)
144    }
145
146    /// Compute the return length of the function given the input arguments.
147    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
148        self.vtable.return_len(args)
149    }
150
151    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
152    pub fn is_elementwise(&self) -> bool {
153        // TODO(ngates): should this just be a constant passed in the constructor?
154        self.vtable.is_elementwise()
155    }
156
157    /// Returns the compute function's kernels.
158    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
159        self.kernels.read().to_vec()
160    }
161}
162
/// VTable for the implementation of a compute function.
///
/// A [`ComputeFn`] wraps one of these and calls into it from [`ComputeFn::invoke`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the compute function entry-point with the given input arguments and options.
    ///
    /// The entry-point logic can short-circuit compute using statistics, update result array
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
    /// to successfully compute a result.
    ///
    /// `kernels` is the list of kernels registered on the owning [`ComputeFn`] at call time.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the return type of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the return length of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the len as computed here.
    /// Scalars are considered to have length 1.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
    /// input and no information is shared between elements.
    ///
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
    ///
    /// All input arrays to an elementwise function *must* have the same length.
    fn is_elementwise(&self) -> bool;
}
193
/// Arguments to a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// The inputs to the function, in positional order.
    pub inputs: &'a [Input<'a>],
    /// Type-erased options; the concrete type is recovered through [`Options::as_any`].
    pub options: &'a dyn Options,
}
200
/// For unary compute functions, it's useful to just have this short-cut.
///
/// Built from an [`InvocationArgs`] via `TryFrom`, with the options downcast to `O`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The invocation options as their concrete type.
    pub options: &'a O,
}
206
207impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
208    type Error = VortexError;
209
210    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
211        if value.inputs.len() != 1 {
212            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
213        }
214        let array = value.inputs[0]
215            .array()
216            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
217        let options =
218            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
219                vortex_err!("Expected options to be of type {}", type_name::<O>())
220            })?;
221        Ok(UnaryArgs { array, options })
222    }
223}
224
/// For binary compute functions, it's useful to just have this short-cut.
///
/// Built from an [`InvocationArgs`] via `TryFrom`, with the options downcast to `O`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options as their concrete type.
    pub options: &'a O,
}
231
232impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
233    type Error = VortexError;
234
235    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
236        if value.inputs.len() != 2 {
237            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
238        }
239        let lhs = value.inputs[0]
240            .array()
241            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
242        let rhs = value.inputs[1]
243            .array()
244            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
245        let options =
246            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
247                vortex_err!("Expected options to be of type {}", type_name::<O>())
248            })?;
249        Ok(BinaryArgs { lhs, rhs, options })
250    }
251}
252
/// Input to a compute function.
///
/// Each variant borrows its payload for the lifetime `'a`.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
261
262impl Debug for Input<'_> {
263    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264        let mut f = f.debug_struct("Input");
265        match self {
266            Input::Scalar(scalar) => f.field("Scalar", scalar),
267            Input::Array(array) => f.field("Array", array),
268            Input::Mask(mask) => f.field("Mask", mask),
269            Input::Builder(builder) => f.field("Builder", &builder.len()),
270            Input::DType(dtype) => f.field("DType", dtype),
271        };
272        f.finish()
273    }
274}
275
276impl<'a> From<&'a dyn Array> for Input<'a> {
277    fn from(value: &'a dyn Array) -> Self {
278        Input::Array(value)
279    }
280}
281
282impl<'a> From<&'a Scalar> for Input<'a> {
283    fn from(value: &'a Scalar) -> Self {
284        Input::Scalar(value)
285    }
286}
287
288impl<'a> From<&'a Mask> for Input<'a> {
289    fn from(value: &'a Mask) -> Self {
290        Input::Mask(value)
291    }
292}
293
294impl<'a> From<&'a DType> for Input<'a> {
295    fn from(value: &'a DType) -> Self {
296        Input::DType(value)
297    }
298}
299
300impl<'a> Input<'a> {
301    pub fn scalar(&self) -> Option<&'a Scalar> {
302        match self {
303            Input::Scalar(scalar) => Some(*scalar),
304            _ => None,
305        }
306    }
307
308    pub fn array(&self) -> Option<&'a dyn Array> {
309        match self {
310            Input::Array(array) => Some(*array),
311            _ => None,
312        }
313    }
314
315    pub fn mask(&self) -> Option<&'a Mask> {
316        match self {
317            Input::Mask(mask) => Some(*mask),
318            _ => None,
319        }
320    }
321
322    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
323        match self {
324            Input::Builder(builder) => Some(*builder),
325            _ => None,
326        }
327    }
328
329    pub fn dtype(&self) -> Option<&'a DType> {
330        match self {
331            Input::DType(dtype) => Some(*dtype),
332            _ => None,
333        }
334    }
335}
336
/// Output from a compute function.
///
/// A result is either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result; [`Output::len`] reports this as length 1.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
343
344#[allow(clippy::len_without_is_empty)]
345impl Output {
346    pub fn dtype(&self) -> &DType {
347        match self {
348            Output::Scalar(scalar) => scalar.dtype(),
349            Output::Array(array) => array.dtype(),
350        }
351    }
352
353    pub fn len(&self) -> usize {
354        match self {
355            Output::Scalar(_) => 1,
356            Output::Array(array) => array.len(),
357        }
358    }
359
360    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
361        match self {
362            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
363            Output::Scalar(scalar) => Ok(scalar),
364        }
365    }
366
367    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
368        match self {
369            Output::Array(array) => Ok(array),
370            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
371        }
372    }
373}
374
375impl From<ArrayRef> for Output {
376    fn from(value: ArrayRef) -> Self {
377        Output::Array(value)
378    }
379}
380
381impl From<Scalar> for Output {
382    fn from(value: Scalar) -> Self {
383        Output::Scalar(value)
384    }
385}
386
/// Options for a compute function invocation.
///
/// Options are passed around as `&dyn Options`; implementations recover the concrete type by
/// downcasting the value returned from [`Options::as_any`].
pub trait Options: 'static {
    /// Returns `self` as a [`std::any::Any`] so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
391
/// Allows the unit type `()` to be passed as the options value.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
397
/// Compute functions can ask arrays for compute kernels for a given invocation.
///
/// The kernel is invoked with the input arguments and options, and can return `Ok(None)` if it
/// is unable to compute the result for the given inputs due to missing implementation logic,
/// for example if the kernel doesn't support the `LTE` operator. By returning `None`, the kernel
/// is indicating that it cannot compute the result for the given inputs and that another kernel
/// should be tried; it is *not* indicating that the given inputs are invalid for the compute
/// function.
///
/// If the kernel attempts the computation and fails, it should return an `Err`.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Invokes the kernel with the given input arguments and options.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
411
/// Registers a kernel for a compute function by submitting it through `inventory`.
///
/// The expression passed in must have the kernel type expected by the target compute
/// function; see each compute function for the correct type of kernel to register.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}