Skip to main content

vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding can provide its own efficient implementations of these operators;
//! otherwise we decode the array and perform the equivalent operator from Arrow.
11
12use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use is_constant::*;
19pub use is_sorted::*;
20use itertools::Itertools;
21pub use min_max::*;
22pub use nan_count::*;
23use parking_lot::RwLock;
24pub use sum::*;
25use vortex_error::VortexError;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29use vortex_mask::Mask;
30
31use crate::Array;
32use crate::ArrayRef;
33use crate::builders::ArrayBuilder;
34use crate::dtype::DType;
35use crate::scalar::Scalar;
36
37#[cfg(feature = "arbitrary")]
38mod arbitrary;
39#[cfg(feature = "_test-harness")]
40pub mod conformance;
41mod is_constant;
42mod is_sorted;
43mod min_max;
44mod nan_count;
45mod sum;
46
47/// An instance of a compute function holding the implementation vtable and a set of registered
48/// compute kernels.
49pub struct ComputeFn {
50    id: ArcRef<str>,
51    vtable: ArcRef<dyn ComputeFnVTable>,
52    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
53}
54
55/// Force all the default [`ComputeFn`] vtables to register all available compute kernels.
56///
57/// Mostly useful for small benchmarks where the overhead might cause noise depending on the order of benchmarks.
58pub fn warm_up_vtables() {
59    #[allow(unused_qualifications)]
60    is_constant::warm_up_vtable();
61    is_sorted::warm_up_vtable();
62    min_max::warm_up_vtable();
63    nan_count::warm_up_vtable();
64    sum::warm_up_vtable();
65}
66
67impl ComputeFn {
68    /// Create a new compute function from the given [`ComputeFnVTable`].
69    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
70        Self {
71            id,
72            vtable,
73            kernels: Default::default(),
74        }
75    }
76
77    /// Returns the string identifier of the compute function.
78    pub fn id(&self) -> &ArcRef<str> {
79        &self.id
80    }
81
82    /// Register a kernel for the compute function.
83    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
84        self.kernels.write().push(kernel);
85    }
86
87    /// Invokes the compute function with the given arguments.
88    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
89        // Perform some pre-condition checks against the arguments and the function properties.
90        if self.is_elementwise() {
91            // For element-wise functions, all input arrays must be the same length.
92            if !args
93                .inputs
94                .iter()
95                .filter_map(|input| input.array())
96                .map(|array| array.len())
97                .all_equal()
98            {
99                vortex_bail!(
100                    "Compute function {} is elementwise but input arrays have different lengths",
101                    self.id
102                );
103            }
104        }
105
106        let expected_dtype = self.vtable.return_dtype(args)?;
107        let expected_len = self.vtable.return_len(args)?;
108
109        let output = self.vtable.invoke(args, &self.kernels.read())?;
110
111        if output.dtype() != &expected_dtype {
112            vortex_bail!(
113                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
114                self.id,
115                output.dtype(),
116                &expected_dtype,
117                args.inputs
118                    .iter()
119                    .filter_map(|input| input.array())
120                    .format_with(",", |array, f| f(&array.encoding_id()))
121            );
122        }
123        if output.len() != expected_len {
124            vortex_bail!(
125                "Internal error: compute function {} returned a result of length {} but expected {}",
126                self.id,
127                output.len(),
128                expected_len
129            );
130        }
131
132        Ok(output)
133    }
134
135    /// Compute the return type of the function given the input arguments.
136    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
137        self.vtable.return_dtype(args)
138    }
139
140    /// Compute the return length of the function given the input arguments.
141    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
142        self.vtable.return_len(args)
143    }
144
145    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
146    pub fn is_elementwise(&self) -> bool {
147        // TODO(ngates): should this just be a constant passed in the constructor?
148        self.vtable.is_elementwise()
149    }
150
151    /// Returns the compute function's kernels.
152    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
153        self.kernels.read().to_vec()
154    }
155}
156
157/// VTable for the implementation of a compute function.
158pub trait ComputeFnVTable: 'static + Send + Sync {
159    /// Invokes the compute function entry-point with the given input arguments and options.
160    ///
161    /// The entry-point logic can short-circuit compute using statistics, update result array
162    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
163    /// to successfully compute a result.
164    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
165    -> VortexResult<Output>;
166
167    /// Computes the return type of the function given the input arguments.
168    ///
169    /// All kernel implementations will be validated to return the [`DType`] as computed here.
170    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
171
172    /// Computes the return length of the function given the input arguments.
173    ///
174    /// All kernel implementations will be validated to return the len as computed here.
175    /// Scalars are considered to have length 1.
176    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
177
178    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
179    /// input and no information is shared between elements.
180    ///
181    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
182    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
183    ///
184    /// All input arrays to an elementwise function *must* have the same length.
185    fn is_elementwise(&self) -> bool;
186}
187
188/// Arguments to a compute function invocation.
189#[derive(Clone)]
190pub struct InvocationArgs<'a> {
191    pub inputs: &'a [Input<'a>],
192    pub options: &'a dyn Options,
193}
194
195/// For unary compute functions, it's useful to just have this short-cut.
196pub struct UnaryArgs<'a, O: Options> {
197    pub array: &'a dyn Array,
198    pub options: &'a O,
199}
200
201impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
202    type Error = VortexError;
203
204    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
205        if value.inputs.len() != 1 {
206            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
207        }
208        let array = value.inputs[0]
209            .array()
210            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
211        let options =
212            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
213                vortex_err!("Expected options to be of type {}", type_name::<O>())
214            })?;
215        Ok(UnaryArgs { array, options })
216    }
217}
218
219/// For binary compute functions, it's useful to just have this short-cut.
220pub struct BinaryArgs<'a, O: Options> {
221    pub lhs: &'a dyn Array,
222    pub rhs: &'a dyn Array,
223    pub options: &'a O,
224}
225
226impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
227    type Error = VortexError;
228
229    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
230        if value.inputs.len() != 2 {
231            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
232        }
233        let lhs = value.inputs[0]
234            .array()
235            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
236        let rhs = value.inputs[1]
237            .array()
238            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
239        let options =
240            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
241                vortex_err!("Expected options to be of type {}", type_name::<O>())
242            })?;
243        Ok(BinaryArgs { lhs, rhs, options })
244    }
245}
246
247/// Input to a compute function.
248pub enum Input<'a> {
249    Scalar(&'a Scalar),
250    Array(&'a dyn Array),
251    Mask(&'a Mask),
252    Builder(&'a mut dyn ArrayBuilder),
253    DType(&'a DType),
254}
255
256impl Debug for Input<'_> {
257    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
258        let mut f = f.debug_struct("Input");
259        match self {
260            Input::Scalar(scalar) => f.field("Scalar", scalar),
261            Input::Array(array) => f.field("Array", array),
262            Input::Mask(mask) => f.field("Mask", mask),
263            Input::Builder(builder) => f.field("Builder", &builder.len()),
264            Input::DType(dtype) => f.field("DType", dtype),
265        };
266        f.finish()
267    }
268}
269
270impl<'a> From<&'a dyn Array> for Input<'a> {
271    fn from(value: &'a dyn Array) -> Self {
272        Input::Array(value)
273    }
274}
275
276impl<'a> From<&'a ArrayRef> for Input<'a> {
277    fn from(value: &'a ArrayRef) -> Self {
278        Input::Array(value.as_ref())
279    }
280}
281
282impl<'a> From<&'a Scalar> for Input<'a> {
283    fn from(value: &'a Scalar) -> Self {
284        Input::Scalar(value)
285    }
286}
287
288impl<'a> From<&'a Mask> for Input<'a> {
289    fn from(value: &'a Mask) -> Self {
290        Input::Mask(value)
291    }
292}
293
294impl<'a> From<&'a DType> for Input<'a> {
295    fn from(value: &'a DType) -> Self {
296        Input::DType(value)
297    }
298}
299
300impl<'a> Input<'a> {
301    pub fn scalar(&self) -> Option<&'a Scalar> {
302        if let Input::Scalar(scalar) = self {
303            Some(*scalar)
304        } else {
305            None
306        }
307    }
308
309    pub fn array(&self) -> Option<&'a dyn Array> {
310        if let Input::Array(array) = self {
311            Some(*array)
312        } else {
313            None
314        }
315    }
316
317    pub fn mask(&self) -> Option<&'a Mask> {
318        if let Input::Mask(mask) = self {
319            Some(*mask)
320        } else {
321            None
322        }
323    }
324
325    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
326        if let Input::Builder(builder) = self {
327            Some(*builder)
328        } else {
329            None
330        }
331    }
332
333    pub fn dtype(&self) -> Option<&'a DType> {
334        if let Input::DType(dtype) = self {
335            Some(*dtype)
336        } else {
337            None
338        }
339    }
340}
341
342/// Output from a compute function.
343#[derive(Debug)]
344pub enum Output {
345    Scalar(Scalar),
346    Array(ArrayRef),
347}
348
349#[expect(
350    clippy::len_without_is_empty,
351    reason = "Output is always non-empty (scalar has len 1)"
352)]
353impl Output {
354    pub fn dtype(&self) -> &DType {
355        match self {
356            Output::Scalar(scalar) => scalar.dtype(),
357            Output::Array(array) => array.dtype(),
358        }
359    }
360
361    pub fn len(&self) -> usize {
362        match self {
363            Output::Scalar(_) => 1,
364            Output::Array(array) => array.len(),
365        }
366    }
367
368    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
369        match self {
370            Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
371            Output::Scalar(scalar) => Ok(scalar),
372        }
373    }
374
375    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
376        match self {
377            Output::Array(array) => Ok(array),
378            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
379        }
380    }
381}
382
383impl From<ArrayRef> for Output {
384    fn from(value: ArrayRef) -> Self {
385        Output::Array(value)
386    }
387}
388
389impl From<Scalar> for Output {
390    fn from(value: Scalar) -> Self {
391        Output::Scalar(value)
392    }
393}
394
/// Options for a compute function invocation.
///
/// Options are passed around as `&dyn Options`; the concrete type is recovered by downcasting
/// the value returned from [`Options::as_any`].
pub trait Options: 'static {
    /// Returns `self` as [`Any`] so it can be downcast to its concrete options type.
    fn as_any(&self) -> &dyn Any;
}

/// The unit type can serve as the options value when a function needs none.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
405
406/// Compute functions can ask arrays for compute kernels for a given invocation.
407///
408/// The kernel is invoked with the input arguments and options, and can return `None` if it is
409/// unable to compute the result for the given inputs due to missing implementation logic.
410/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
411/// is indicating that it cannot compute the result for the given inputs, and another kernel should
412/// be tried. *Not* that the given inputs are invalid for the compute function.
413///
414/// If the kernel fails to compute a result, it should return a `Some` with the error.
415pub trait Kernel: 'static + Send + Sync + Debug {
416    /// Invokes the kernel with the given input arguments and options.
417    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
418}
419
/// Register a kernel for a compute function.
///
/// Expands to `inventory::submit!` (through the `$crate::aliases::inventory` re-export) on the
/// given expression. See each compute function for the correct type of kernel to register.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}