Skip to main content

vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementations of these operators;
//! otherwise we decode the array and perform the equivalent operation using Arrow.
11
12use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use boolean::*;
19#[expect(deprecated)]
20pub use cast::cast;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24#[expect(deprecated)]
25pub use invert::invert;
26pub use is_constant::*;
27pub use is_sorted::*;
28use itertools::Itertools;
29#[expect(deprecated)]
30pub use list_contains::list_contains;
31pub use mask::*;
32pub use min_max::*;
33pub use nan_count::*;
34pub use numeric::*;
35use parking_lot::RwLock;
36pub use sum::*;
37use vortex_error::VortexError;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_err;
41use vortex_mask::Mask;
42pub use zip::*;
43
44use crate::Array;
45use crate::ArrayRef;
46use crate::builders::ArrayBuilder;
47use crate::dtype::DType;
48use crate::scalar::Scalar;
49
50#[cfg(feature = "arbitrary")]
51mod arbitrary;
52mod boolean;
53mod cast;
54mod compare;
55#[cfg(feature = "_test-harness")]
56pub mod conformance;
57mod fill_null;
58mod filter;
59mod invert;
60mod is_constant;
61mod is_sorted;
62mod list_contains;
63mod mask;
64mod min_max;
65mod nan_count;
66mod numeric;
67mod sum;
68mod zip;
69
/// An instance of a compute function holding the implementation vtable and a set of registered
/// compute kernels.
pub struct ComputeFn {
    /// String identifier of the compute function, as returned by [`ComputeFn::id`].
    id: ArcRef<str>,
    /// Implementation vtable that [`ComputeFn::invoke`] dispatches to.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added through [`ComputeFn::register_kernel`]. The lock lets kernels be
    /// registered through a shared reference.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
77
/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
///
/// Mostly useful for small benchmarks where the overhead might cause noise depending on the order of benchmarks.
///
/// Calls the `warm_up_vtable` function of the `is_constant`, `is_sorted`, `min_max`,
/// `nan_count` and `sum` modules, in that order.
pub fn warm_up_vtables() {
    // NOTE(review): as written, this attribute is attached to the first call statement only.
    #[allow(unused_qualifications)]
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    sum::warm_up_vtable();
}
89
90impl ComputeFn {
91    /// Create a new compute function from the given [`ComputeFnVTable`].
92    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
93        Self {
94            id,
95            vtable,
96            kernels: Default::default(),
97        }
98    }
99
100    /// Returns the string identifier of the compute function.
101    pub fn id(&self) -> &ArcRef<str> {
102        &self.id
103    }
104
105    /// Register a kernel for the compute function.
106    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
107        self.kernels.write().push(kernel);
108    }
109
110    /// Invokes the compute function with the given arguments.
111    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
112        // Perform some pre-condition checks against the arguments and the function properties.
113        if self.is_elementwise() {
114            // For element-wise functions, all input arrays must be the same length.
115            if !args
116                .inputs
117                .iter()
118                .filter_map(|input| input.array())
119                .map(|array| array.len())
120                .all_equal()
121            {
122                vortex_bail!(
123                    "Compute function {} is elementwise but input arrays have different lengths",
124                    self.id
125                );
126            }
127        }
128
129        let expected_dtype = self.vtable.return_dtype(args)?;
130        let expected_len = self.vtable.return_len(args)?;
131
132        let output = self.vtable.invoke(args, &self.kernels.read())?;
133
134        if output.dtype() != &expected_dtype {
135            vortex_bail!(
136                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
137                self.id,
138                output.dtype(),
139                &expected_dtype,
140                args.inputs
141                    .iter()
142                    .filter_map(|input| input.array())
143                    .format_with(",", |array, f| f(&array.encoding_id()))
144            );
145        }
146        if output.len() != expected_len {
147            vortex_bail!(
148                "Internal error: compute function {} returned a result of length {} but expected {}",
149                self.id,
150                output.len(),
151                expected_len
152            );
153        }
154
155        Ok(output)
156    }
157
158    /// Compute the return type of the function given the input arguments.
159    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
160        self.vtable.return_dtype(args)
161    }
162
163    /// Compute the return length of the function given the input arguments.
164    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
165        self.vtable.return_len(args)
166    }
167
168    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
169    pub fn is_elementwise(&self) -> bool {
170        // TODO(ngates): should this just be a constant passed in the constructor?
171        self.vtable.is_elementwise()
172    }
173
174    /// Returns the compute function's kernels.
175    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
176        self.kernels.read().to_vec()
177    }
178}
179
/// VTable for the implementation of a compute function.
///
/// A [`ComputeFn`] wraps one of these and validates its results in [`ComputeFn::invoke`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the compute function entry-point with the given input arguments and options.
    ///
    /// The entry-point logic can short-circuit compute using statistics, update result array
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
    /// to successfully compute a result.
    ///
    /// When called through [`ComputeFn::invoke`], `kernels` is the list of kernels registered
    /// on that [`ComputeFn`].
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the return type of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
    /// [`ComputeFn::invoke`] reports a mismatch as an internal error.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the return length of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the len as computed here.
    /// Scalars are considered to have length 1.
    /// [`ComputeFn::invoke`] reports a mismatch as an internal error.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
    /// input and no information is shared between elements.
    ///
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
    ///
    /// All input arrays to an elementwise function *must* have the same length.
    /// [`ComputeFn::invoke`] checks this before dispatching to [`ComputeFnVTable::invoke`].
    fn is_elementwise(&self) -> bool;
}
210
/// Arguments to a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// The positional inputs of the invocation.
    pub inputs: &'a [Input<'a>],
    /// Type-erased options; implementations downcast them through [`Options::as_any`].
    pub options: &'a dyn Options,
}
217
/// For unary compute functions, it's useful to just have this short-cut.
///
/// Built from an [`InvocationArgs`] through the `TryFrom` implementation.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The invocation options, downcast to the concrete type `O`.
    pub options: &'a O,
}
223
224impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
225    type Error = VortexError;
226
227    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
228        if value.inputs.len() != 1 {
229            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
230        }
231        let array = value.inputs[0]
232            .array()
233            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
234        let options =
235            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
236                vortex_err!("Expected options to be of type {}", type_name::<O>())
237            })?;
238        Ok(UnaryArgs { array, options })
239    }
240}
241
/// For binary compute functions, it's useful to just have this short-cut.
///
/// Built from an [`InvocationArgs`] through the `TryFrom` implementation.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options, downcast to the concrete type `O`.
    pub options: &'a O,
}
248
249impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
250    type Error = VortexError;
251
252    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
253        if value.inputs.len() != 2 {
254            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
255        }
256        let lhs = value.inputs[0]
257            .array()
258            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
259        let rhs = value.inputs[1]
260            .array()
261            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
262        let options =
263            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
264                vortex_err!("Expected options to be of type {}", type_name::<O>())
265            })?;
266        Ok(BinaryArgs { lhs, rhs, options })
267    }
268}
269
/// Input to a compute function.
///
/// Every variant borrows its value for the lifetime `'a`.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
278
279impl Debug for Input<'_> {
280    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281        let mut f = f.debug_struct("Input");
282        match self {
283            Input::Scalar(scalar) => f.field("Scalar", scalar),
284            Input::Array(array) => f.field("Array", array),
285            Input::Mask(mask) => f.field("Mask", mask),
286            Input::Builder(builder) => f.field("Builder", &builder.len()),
287            Input::DType(dtype) => f.field("DType", dtype),
288        };
289        f.finish()
290    }
291}
292
293impl<'a> From<&'a dyn Array> for Input<'a> {
294    fn from(value: &'a dyn Array) -> Self {
295        Input::Array(value)
296    }
297}
298
299impl<'a> From<&'a Scalar> for Input<'a> {
300    fn from(value: &'a Scalar) -> Self {
301        Input::Scalar(value)
302    }
303}
304
305impl<'a> From<&'a Mask> for Input<'a> {
306    fn from(value: &'a Mask) -> Self {
307        Input::Mask(value)
308    }
309}
310
311impl<'a> From<&'a DType> for Input<'a> {
312    fn from(value: &'a DType) -> Self {
313        Input::DType(value)
314    }
315}
316
317impl<'a> Input<'a> {
318    pub fn scalar(&self) -> Option<&'a Scalar> {
319        if let Input::Scalar(scalar) = self {
320            Some(*scalar)
321        } else {
322            None
323        }
324    }
325
326    pub fn array(&self) -> Option<&'a dyn Array> {
327        if let Input::Array(array) = self {
328            Some(*array)
329        } else {
330            None
331        }
332    }
333
334    pub fn mask(&self) -> Option<&'a Mask> {
335        if let Input::Mask(mask) = self {
336            Some(*mask)
337        } else {
338            None
339        }
340    }
341
342    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
343        if let Input::Builder(builder) = self {
344            Some(*builder)
345        } else {
346            None
347        }
348    }
349
350    pub fn dtype(&self) -> Option<&'a DType> {
351        if let Input::DType(dtype) = self {
352            Some(*dtype)
353        } else {
354            None
355        }
356    }
357}
358
/// Output from a compute function.
#[derive(Debug)]
pub enum Output {
    /// An owned scalar result; [`Output::len`] reports it as length 1.
    Scalar(Scalar),
    /// An owned array result.
    Array(ArrayRef),
}
365
366#[expect(
367    clippy::len_without_is_empty,
368    reason = "Output is always non-empty (scalar has len 1)"
369)]
370impl Output {
371    pub fn dtype(&self) -> &DType {
372        match self {
373            Output::Scalar(scalar) => scalar.dtype(),
374            Output::Array(array) => array.dtype(),
375        }
376    }
377
378    pub fn len(&self) -> usize {
379        match self {
380            Output::Scalar(_) => 1,
381            Output::Array(array) => array.len(),
382        }
383    }
384
385    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
386        match self {
387            Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
388            Output::Scalar(scalar) => Ok(scalar),
389        }
390    }
391
392    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
393        match self {
394            Output::Array(array) => Ok(array),
395            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
396        }
397    }
398}
399
400impl From<ArrayRef> for Output {
401    fn from(value: ArrayRef) -> Self {
402        Output::Array(value)
403    }
404}
405
406impl From<Scalar> for Output {
407    fn from(value: Scalar) -> Self {
408        Output::Scalar(value)
409    }
410}
411
/// Options for a compute function invocation.
///
/// Options travel type-erased as `&dyn Options` inside [`InvocationArgs`].
pub trait Options: 'static {
    /// Returns `self` as [`Any`] so that callers can downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}
416
/// The unit type is a valid set of options.
// NOTE(review): presumably meant for compute functions that take no options — confirm.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
422
/// Compute functions can ask arrays for compute kernels for a given invocation.
///
/// The kernel is invoked with the input arguments and options, and can return `Ok(None)` if it
/// is unable to compute the result for the given inputs due to missing implementation logic.
/// For example, if kernel doesn't support the `LTE` operator. By returning `Ok(None)`, the kernel
/// is indicating that it cannot compute the result for the given inputs, and another kernel should
/// be tried. *Not* that the given inputs are invalid for the compute function.
///
/// If the kernel fails to compute a result, it should return an `Err` with the error.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Invokes the kernel with the given input arguments and options.
    ///
    /// Returns `Ok(Some(output))` on success, `Ok(None)` when this kernel does not handle the
    /// given inputs, and `Err` when the computation itself fails.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
436
/// Register a kernel for a compute function.
/// See each compute function for the correct type of kernel to register.
///
/// Expands to a single `inventory::submit!` call (reached through `$crate::aliases`) on the
/// given expression.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}