vortex_array/compute/
mod.rs

1//! Compute kernels on top of Vortex Arrays.
2//!
3//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
4//! and filter Vortex Arrays in their encoded forms.
5//!
//! Every array encoding can provide its own efficient implementation of these operators;
//! otherwise, we decode the array and perform the equivalent operation using Arrow.
8
9use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11
12use arcref::ArcRef;
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::*;
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::*;
24pub use list_contains::*;
25pub use mask::*;
26pub use min_max::*;
27pub use nan_count::*;
28pub use numeric::*;
29use parking_lot::RwLock;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list_contains;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
62/// An instance of a compute function holding the implementation vtable and a set of registered
63/// compute kernels.
64pub struct ComputeFn {
65    id: ArcRef<str>,
66    vtable: ArcRef<dyn ComputeFnVTable>,
67    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
68}
69
70impl ComputeFn {
71    /// Create a new compute function from the given [`ComputeFnVTable`].
72    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73        Self {
74            id,
75            vtable,
76            kernels: Default::default(),
77        }
78    }
79
80    /// Returns the string identifier of the compute function.
81    pub fn id(&self) -> &ArcRef<str> {
82        &self.id
83    }
84
85    /// Register a kernel for the compute function.
86    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87        self.kernels.write().push(kernel);
88    }
89
90    /// Invokes the compute function with the given arguments.
91    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
92        // Perform some pre-condition checks against the arguments and the function properties.
93        if self.is_elementwise() {
94            // For element-wise functions, all input arrays must be the same length.
95            if !args
96                .inputs
97                .iter()
98                .filter_map(|input| input.array())
99                .map(|array| array.len())
100                .all_equal()
101            {
102                vortex_bail!(
103                    "Compute function {} is elementwise but input arrays have different lengths",
104                    self.id
105                );
106            }
107        }
108
109        let expected_dtype = self.vtable.return_dtype(args)?;
110        let expected_len = self.vtable.return_len(args)?;
111
112        let output = self.vtable.invoke(args, &self.kernels.read())?;
113
114        if output.dtype() != &expected_dtype {
115            vortex_bail!(
116                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
117                self.id,
118                output.dtype(),
119                &expected_dtype,
120                args.inputs
121                    .iter()
122                    .filter_map(|input| input.array())
123                    .format_with(",", |array, f| f(&array.tree_display()))
124            );
125        }
126        if output.len() != expected_len {
127            vortex_bail!(
128                "Internal error: compute function {} returned a result of length {} but expected {}",
129                self.id,
130                output.len(),
131                expected_len
132            );
133        }
134
135        Ok(output)
136    }
137
138    /// Compute the return type of the function given the input arguments.
139    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
140        self.vtable.return_dtype(args)
141    }
142
143    /// Compute the return length of the function given the input arguments.
144    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
145        self.vtable.return_len(args)
146    }
147
148    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
149    pub fn is_elementwise(&self) -> bool {
150        // TODO(ngates): should this just be a constant passed in the constructor?
151        self.vtable.is_elementwise()
152    }
153
154    /// Returns the compute function's kernels.
155    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
156        self.kernels.read().to_vec()
157    }
158}
159
160/// VTable for the implementation of a compute function.
161pub trait ComputeFnVTable: 'static + Send + Sync {
162    /// Invokes the compute function entry-point with the given input arguments and options.
163    ///
164    /// The entry-point logic can short-circuit compute using statistics, update result array
165    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
166    /// to successfully compute a result.
167    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
168    -> VortexResult<Output>;
169
170    /// Computes the return type of the function given the input arguments.
171    ///
172    /// All kernel implementations will be validated to return the [`DType`] as computed here.
173    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
174
175    /// Computes the return length of the function given the input arguments.
176    ///
177    /// All kernel implementations will be validated to return the len as computed here.
178    /// Scalars are considered to have length 1.
179    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
180
181    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
182    /// input and no information is shared between elements.
183    ///
184    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
185    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
186    ///
187    /// All input arrays to an elementwise function *must* have the same length.
188    fn is_elementwise(&self) -> bool;
189}
190
191/// Arguments to a compute function invocation.
192#[derive(Clone)]
193pub struct InvocationArgs<'a> {
194    pub inputs: &'a [Input<'a>],
195    pub options: &'a dyn Options,
196}
197
198/// For unary compute functions, it's useful to just have this short-cut.
199pub struct UnaryArgs<'a, O: Options> {
200    pub array: &'a dyn Array,
201    pub options: &'a O,
202}
203
204impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
205    type Error = VortexError;
206
207    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
208        if value.inputs.len() != 1 {
209            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
210        }
211        let array = value.inputs[0]
212            .array()
213            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
214        let options =
215            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
216                vortex_err!("Expected options to be of type {}", type_name::<O>())
217            })?;
218        Ok(UnaryArgs { array, options })
219    }
220}
221
222/// For binary compute functions, it's useful to just have this short-cut.
223pub struct BinaryArgs<'a, O: Options> {
224    pub lhs: &'a dyn Array,
225    pub rhs: &'a dyn Array,
226    pub options: &'a O,
227}
228
229impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
230    type Error = VortexError;
231
232    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
233        if value.inputs.len() != 2 {
234            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
235        }
236        let lhs = value.inputs[0]
237            .array()
238            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
239        let rhs = value.inputs[1]
240            .array()
241            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
242        let options =
243            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
244                vortex_err!("Expected options to be of type {}", type_name::<O>())
245            })?;
246        Ok(BinaryArgs { lhs, rhs, options })
247    }
248}
249
250/// Input to a compute function.
251pub enum Input<'a> {
252    Scalar(&'a Scalar),
253    Array(&'a dyn Array),
254    Mask(&'a Mask),
255    Builder(&'a mut dyn ArrayBuilder),
256    DType(&'a DType),
257}
258
259impl Debug for Input<'_> {
260    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
261        let mut f = f.debug_struct("Input");
262        match self {
263            Input::Scalar(scalar) => f.field("Scalar", scalar),
264            Input::Array(array) => f.field("Array", array),
265            Input::Mask(mask) => f.field("Mask", mask),
266            Input::Builder(builder) => f.field("Builder", &builder.len()),
267            Input::DType(dtype) => f.field("DType", dtype),
268        };
269        f.finish()
270    }
271}
272
273impl<'a> From<&'a dyn Array> for Input<'a> {
274    fn from(value: &'a dyn Array) -> Self {
275        Input::Array(value)
276    }
277}
278
279impl<'a> From<&'a Scalar> for Input<'a> {
280    fn from(value: &'a Scalar) -> Self {
281        Input::Scalar(value)
282    }
283}
284
285impl<'a> From<&'a Mask> for Input<'a> {
286    fn from(value: &'a Mask) -> Self {
287        Input::Mask(value)
288    }
289}
290
291impl<'a> From<&'a DType> for Input<'a> {
292    fn from(value: &'a DType) -> Self {
293        Input::DType(value)
294    }
295}
296
297impl<'a> Input<'a> {
298    pub fn scalar(&self) -> Option<&'a Scalar> {
299        match self {
300            Input::Scalar(scalar) => Some(*scalar),
301            _ => None,
302        }
303    }
304
305    pub fn array(&self) -> Option<&'a dyn Array> {
306        match self {
307            Input::Array(array) => Some(*array),
308            _ => None,
309        }
310    }
311
312    pub fn mask(&self) -> Option<&'a Mask> {
313        match self {
314            Input::Mask(mask) => Some(*mask),
315            _ => None,
316        }
317    }
318
319    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
320        match self {
321            Input::Builder(builder) => Some(*builder),
322            _ => None,
323        }
324    }
325
326    pub fn dtype(&self) -> Option<&'a DType> {
327        match self {
328            Input::DType(dtype) => Some(*dtype),
329            _ => None,
330        }
331    }
332}
333
334/// Output from a compute function.
335#[derive(Debug)]
336pub enum Output {
337    Scalar(Scalar),
338    Array(ArrayRef),
339}
340
341#[allow(clippy::len_without_is_empty)]
342impl Output {
343    pub fn dtype(&self) -> &DType {
344        match self {
345            Output::Scalar(scalar) => scalar.dtype(),
346            Output::Array(array) => array.dtype(),
347        }
348    }
349
350    pub fn len(&self) -> usize {
351        match self {
352            Output::Scalar(_) => 1,
353            Output::Array(array) => array.len(),
354        }
355    }
356
357    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
358        match self {
359            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
360            Output::Scalar(scalar) => Ok(scalar),
361        }
362    }
363
364    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
365        match self {
366            Output::Array(array) => Ok(array),
367            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
368        }
369    }
370}
371
372impl From<ArrayRef> for Output {
373    fn from(value: ArrayRef) -> Self {
374        Output::Array(value)
375    }
376}
377
378impl From<Scalar> for Output {
379    fn from(value: Scalar) -> Self {
380        Output::Scalar(value)
381    }
382}
383
/// Function-specific options passed alongside the inputs of a compute function invocation.
///
/// Options are carried as `&dyn Options` and recovered as a concrete type by downcasting the
/// result of [`Options::as_any`].
pub trait Options: 'static {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}
388
389impl Options for () {
390    fn as_any(&self) -> &dyn Any {
391        self
392    }
393}
394
395/// Compute functions can ask arrays for compute kernels for a given invocation.
396///
397/// The kernel is invoked with the input arguments and options, and can return `None` if it is
398/// unable to compute the result for the given inputs due to missing implementation logic.
399/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
400/// is indicating that it cannot compute the result for the given inputs, and another kernel should
401/// be tried. *Not* that the given inputs are invalid for the compute function.
402///
403/// If the kernel fails to compute a result, it should return a `Some` with the error.
404pub trait Kernel: 'static + Send + Sync + Debug {
405    /// Invokes the kernel with the given input arguments and options.
406    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
407}
408
/// Register a kernel for a compute function.
///
/// Expands to `$crate::aliases::inventory::submit!` with the given expression. See each compute
/// function for the kernel type it expects to be registered.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}