vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementation of these operators;
//! otherwise we decode the array and fall back to the equivalent operator from Arrow.
11
12use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46mod between;
47mod boolean;
48mod cast;
49mod compare;
50#[cfg(feature = "test-harness")]
51pub mod conformance;
52mod fill_null;
53mod filter;
54mod invert;
55mod is_constant;
56mod is_sorted;
57mod like;
58mod list_contains;
59mod mask;
60mod min_max;
61mod nan_count;
62mod numeric;
63mod sum;
64mod take;
65mod zip;
66
67/// An instance of a compute function holding the implementation vtable and a set of registered
68/// compute kernels.
69pub struct ComputeFn {
70    id: ArcRef<str>,
71    vtable: ArcRef<dyn ComputeFnVTable>,
72    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
73}
74
75impl ComputeFn {
76    /// Create a new compute function from the given [`ComputeFnVTable`].
77    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
78        Self {
79            id,
80            vtable,
81            kernels: Default::default(),
82        }
83    }
84
85    /// Returns the string identifier of the compute function.
86    pub fn id(&self) -> &ArcRef<str> {
87        &self.id
88    }
89
90    /// Register a kernel for the compute function.
91    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
92        self.kernels.write().push(kernel);
93    }
94
95    /// Invokes the compute function with the given arguments.
96    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
97        // Perform some pre-condition checks against the arguments and the function properties.
98        if self.is_elementwise() {
99            // For element-wise functions, all input arrays must be the same length.
100            if !args
101                .inputs
102                .iter()
103                .filter_map(|input| input.array())
104                .map(|array| array.len())
105                .all_equal()
106            {
107                vortex_bail!(
108                    "Compute function {} is elementwise but input arrays have different lengths",
109                    self.id
110                );
111            }
112        }
113
114        let expected_dtype = self.vtable.return_dtype(args)?;
115        let expected_len = self.vtable.return_len(args)?;
116
117        let output = self.vtable.invoke(args, &self.kernels.read())?;
118
119        if output.dtype() != &expected_dtype {
120            vortex_bail!(
121                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
122                self.id,
123                output.dtype(),
124                &expected_dtype,
125                args.inputs
126                    .iter()
127                    .filter_map(|input| input.array())
128                    .format_with(",", |array, f| f(&array.display_tree()))
129            );
130        }
131        if output.len() != expected_len {
132            vortex_bail!(
133                "Internal error: compute function {} returned a result of length {} but expected {}",
134                self.id,
135                output.len(),
136                expected_len
137            );
138        }
139
140        Ok(output)
141    }
142
143    /// Compute the return type of the function given the input arguments.
144    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
145        self.vtable.return_dtype(args)
146    }
147
148    /// Compute the return length of the function given the input arguments.
149    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
150        self.vtable.return_len(args)
151    }
152
153    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
154    pub fn is_elementwise(&self) -> bool {
155        // TODO(ngates): should this just be a constant passed in the constructor?
156        self.vtable.is_elementwise()
157    }
158
159    /// Returns the compute function's kernels.
160    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
161        self.kernels.read().to_vec()
162    }
163}
164
165/// VTable for the implementation of a compute function.
166pub trait ComputeFnVTable: 'static + Send + Sync {
167    /// Invokes the compute function entry-point with the given input arguments and options.
168    ///
169    /// The entry-point logic can short-circuit compute using statistics, update result array
170    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
171    /// to successfully compute a result.
172    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
173    -> VortexResult<Output>;
174
175    /// Computes the return type of the function given the input arguments.
176    ///
177    /// All kernel implementations will be validated to return the [`DType`] as computed here.
178    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
179
180    /// Computes the return length of the function given the input arguments.
181    ///
182    /// All kernel implementations will be validated to return the len as computed here.
183    /// Scalars are considered to have length 1.
184    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
185
186    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
187    /// input and no information is shared between elements.
188    ///
189    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
190    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
191    ///
192    /// All input arrays to an elementwise function *must* have the same length.
193    fn is_elementwise(&self) -> bool;
194}
195
196/// Arguments to a compute function invocation.
197#[derive(Clone)]
198pub struct InvocationArgs<'a> {
199    pub inputs: &'a [Input<'a>],
200    pub options: &'a dyn Options,
201}
202
203/// For unary compute functions, it's useful to just have this short-cut.
204pub struct UnaryArgs<'a, O: Options> {
205    pub array: &'a dyn Array,
206    pub options: &'a O,
207}
208
209impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
210    type Error = VortexError;
211
212    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
213        if value.inputs.len() != 1 {
214            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
215        }
216        let array = value.inputs[0]
217            .array()
218            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
219        let options =
220            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
221                vortex_err!("Expected options to be of type {}", type_name::<O>())
222            })?;
223        Ok(UnaryArgs { array, options })
224    }
225}
226
227/// For binary compute functions, it's useful to just have this short-cut.
228pub struct BinaryArgs<'a, O: Options> {
229    pub lhs: &'a dyn Array,
230    pub rhs: &'a dyn Array,
231    pub options: &'a O,
232}
233
234impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
235    type Error = VortexError;
236
237    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
238        if value.inputs.len() != 2 {
239            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
240        }
241        let lhs = value.inputs[0]
242            .array()
243            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
244        let rhs = value.inputs[1]
245            .array()
246            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
247        let options =
248            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
249                vortex_err!("Expected options to be of type {}", type_name::<O>())
250            })?;
251        Ok(BinaryArgs { lhs, rhs, options })
252    }
253}
254
255/// Input to a compute function.
256pub enum Input<'a> {
257    Scalar(&'a Scalar),
258    Array(&'a dyn Array),
259    Mask(&'a Mask),
260    Builder(&'a mut dyn ArrayBuilder),
261    DType(&'a DType),
262}
263
264impl Debug for Input<'_> {
265    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266        let mut f = f.debug_struct("Input");
267        match self {
268            Input::Scalar(scalar) => f.field("Scalar", scalar),
269            Input::Array(array) => f.field("Array", array),
270            Input::Mask(mask) => f.field("Mask", mask),
271            Input::Builder(builder) => f.field("Builder", &builder.len()),
272            Input::DType(dtype) => f.field("DType", dtype),
273        };
274        f.finish()
275    }
276}
277
278impl<'a> From<&'a dyn Array> for Input<'a> {
279    fn from(value: &'a dyn Array) -> Self {
280        Input::Array(value)
281    }
282}
283
284impl<'a> From<&'a Scalar> for Input<'a> {
285    fn from(value: &'a Scalar) -> Self {
286        Input::Scalar(value)
287    }
288}
289
290impl<'a> From<&'a Mask> for Input<'a> {
291    fn from(value: &'a Mask) -> Self {
292        Input::Mask(value)
293    }
294}
295
296impl<'a> From<&'a DType> for Input<'a> {
297    fn from(value: &'a DType) -> Self {
298        Input::DType(value)
299    }
300}
301
302impl<'a> Input<'a> {
303    pub fn scalar(&self) -> Option<&'a Scalar> {
304        match self {
305            Input::Scalar(scalar) => Some(*scalar),
306            _ => None,
307        }
308    }
309
310    pub fn array(&self) -> Option<&'a dyn Array> {
311        match self {
312            Input::Array(array) => Some(*array),
313            _ => None,
314        }
315    }
316
317    pub fn mask(&self) -> Option<&'a Mask> {
318        match self {
319            Input::Mask(mask) => Some(*mask),
320            _ => None,
321        }
322    }
323
324    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
325        match self {
326            Input::Builder(builder) => Some(*builder),
327            _ => None,
328        }
329    }
330
331    pub fn dtype(&self) -> Option<&'a DType> {
332        match self {
333            Input::DType(dtype) => Some(*dtype),
334            _ => None,
335        }
336    }
337}
338
339/// Output from a compute function.
340#[derive(Debug)]
341pub enum Output {
342    Scalar(Scalar),
343    Array(ArrayRef),
344}
345
346#[allow(clippy::len_without_is_empty)]
347impl Output {
348    pub fn dtype(&self) -> &DType {
349        match self {
350            Output::Scalar(scalar) => scalar.dtype(),
351            Output::Array(array) => array.dtype(),
352        }
353    }
354
355    pub fn len(&self) -> usize {
356        match self {
357            Output::Scalar(_) => 1,
358            Output::Array(array) => array.len(),
359        }
360    }
361
362    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
363        match self {
364            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
365            Output::Scalar(scalar) => Ok(scalar),
366        }
367    }
368
369    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
370        match self {
371            Output::Array(array) => Ok(array),
372            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
373        }
374    }
375}
376
377impl From<ArrayRef> for Output {
378    fn from(value: ArrayRef) -> Self {
379        Output::Array(value)
380    }
381}
382
383impl From<Scalar> for Output {
384    fn from(value: Scalar) -> Self {
385        Output::Scalar(value)
386    }
387}
388
/// Type-erased options passed along with a compute function invocation.
pub trait Options: 'static {
    /// Exposes the options as [`Any`] so that they can be downcast to their concrete type.
    fn as_any(&self) -> &dyn Any;
}
393
394impl Options for () {
395    fn as_any(&self) -> &dyn Any {
396        self
397    }
398}
399
400/// Compute functions can ask arrays for compute kernels for a given invocation.
401///
402/// The kernel is invoked with the input arguments and options, and can return `None` if it is
403/// unable to compute the result for the given inputs due to missing implementation logic.
404/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
405/// is indicating that it cannot compute the result for the given inputs, and another kernel should
406/// be tried. *Not* that the given inputs are invalid for the compute function.
407///
408/// If the kernel fails to compute a result, it should return a `Some` with the error.
409pub trait Kernel: 'static + Send + Sync + Debug {
410    /// Invokes the kernel with the given input arguments and options.
411    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
412}
413
/// Register a kernel for a compute function.
///
/// The given expression is handed to `inventory::submit!`, reached through this crate's
/// `aliases` module. See each compute function for the correct type of kernel to register.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}