vortex_array/compute/
mod.rs

1//! Compute kernels on top of Vortex Arrays.
2//!
3//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
4//! and filter Vortex Arrays in their encoded forms.
5//!
//! Every array encoding can provide its own efficient implementation of these operators;
//! otherwise we decode the array and perform the equivalent operator from Arrow.
8
9use std::any::Any;
10use std::fmt::{Debug, Formatter};
11use std::sync::RwLock;
12
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::{FillNullFn, fill_null};
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::{LikeFn, LikeOptions, like};
24pub use mask::*;
25pub use min_max::{MinMaxFn, MinMaxResult, min_max};
26pub use numeric::*;
27pub use optimize::*;
28pub use scalar_at::{ScalarAtFn, scalar_at};
29pub use search_sorted::*;
30pub use slice::{SliceFn, slice};
31pub use sum::*;
32pub use take::{TakeFn, take, take_into};
33pub use take_from::TakeFromFn;
34pub use to_arrow::*;
35pub use uncompressed_size::*;
36use vortex_dtype::DType;
37use vortex_error::{VortexExpect, VortexResult, vortex_bail};
38use vortex_mask::Mask;
39use vortex_scalar::Scalar;
40
41use crate::arcref::ArcRef;
42use crate::builders::ArrayBuilder;
43use crate::{Array, ArrayRef};
44
45#[cfg(feature = "arbitrary")]
46mod arbitrary;
47mod between;
48mod boolean;
49mod cast;
50mod compare;
51#[cfg(feature = "test-harness")]
52pub mod conformance;
53mod fill_null;
54mod filter;
55mod invert;
56mod is_constant;
57mod is_sorted;
58mod like;
59mod mask;
60mod min_max;
61mod numeric;
62mod optimize;
63mod scalar_at;
64mod search_sorted;
65mod slice;
66mod sum;
67mod take;
68mod take_from;
69mod to_arrow;
70mod uncompressed_size;
71
72/// An instance of a compute function holding the implementation vtable and a set of registered
73/// compute kernels.
74pub struct ComputeFn {
75    id: ArcRef<str>,
76    vtable: ArcRef<dyn ComputeFnVTable>,
77    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
78}
79
80impl ComputeFn {
81    /// Create a new compute function from the given [`ComputeFnVTable`].
82    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
83        Self {
84            id,
85            vtable,
86            kernels: Default::default(),
87        }
88    }
89
90    /// Returns the string identifier of the compute function.
91    pub fn id(&self) -> &ArcRef<str> {
92        &self.id
93    }
94
95    /// Register a kernel for the compute function.
96    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
97        self.kernels
98            .write()
99            .vortex_expect("poisoned lock")
100            .push(kernel);
101    }
102
103    /// Invokes the compute function with the given arguments.
104    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
105        // Perform some pre-condition checks against the arguments and the function properties.
106        if self.is_elementwise() {
107            // For element-wise functions, all input arrays must be the same length.
108            if !args
109                .inputs
110                .iter()
111                .filter_map(|input| input.array())
112                .map(|array| array.len())
113                .all_equal()
114            {
115                vortex_bail!(
116                    "Compute function {} is elementwise but input arrays have different lengths",
117                    self.id
118                );
119            }
120        }
121
122        let expected_dtype = self.vtable.return_dtype(args)?;
123        let expected_len = self.vtable.return_len(args)?;
124
125        let output = self
126            .vtable
127            .invoke(args, &self.kernels.read().vortex_expect("poisoned lock"))?;
128
129        if output.dtype() != &expected_dtype {
130            vortex_bail!(
131                "Internal error: compute function {} returned a result of type {} but expected {}",
132                self.id,
133                output.dtype(),
134                &expected_dtype
135            );
136        }
137        if output.len() != expected_len {
138            vortex_bail!(
139                "Internal error: compute function {} returned a result of length {} but expected {}",
140                self.id,
141                output.len(),
142                expected_len
143            );
144        }
145
146        Ok(output)
147    }
148
149    /// Compute the return type of the function given the input arguments.
150    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
151        self.vtable.return_dtype(args)
152    }
153
154    /// Compute the return length of the function given the input arguments.
155    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
156        self.vtable.return_len(args)
157    }
158
159    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
160    pub fn is_elementwise(&self) -> bool {
161        // TODO(ngates): should this just be a constant passed in the constructor?
162        self.vtable.is_elementwise()
163    }
164}
165
166/// VTable for the implementation of a compute function.
167pub trait ComputeFnVTable: 'static + Send + Sync {
168    /// Invokes the compute function entry-point with the given input arguments and options.
169    ///
170    /// The entry-point logic can short-circuit compute using statistics, update result array
171    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
172    /// to successfully compute a result.
173    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
174    -> VortexResult<Output>;
175
176    /// Computes the return type of the function given the input arguments.
177    ///
178    /// All kernel implementations will be validated to return the [`DType`] as computed here.
179    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
180
181    /// Computes the return length of the function given the input arguments.
182    ///
183    /// All kernel implementations will be validated to return the len as computed here.
184    /// Scalars are considered to have length 1.
185    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
186
187    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
188    /// input and no information is shared between elements.
189    ///
190    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
191    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
192    ///
193    /// All input arrays to an elementwise function *must* have the same length.
194    fn is_elementwise(&self) -> bool;
195}
196
197/// Arguments to a compute function invocation.
198#[derive(Clone)]
199pub struct InvocationArgs<'a> {
200    pub inputs: &'a [Input<'a>],
201    pub options: &'a dyn Options,
202}
203
204/// Input to a compute function.
205pub enum Input<'a> {
206    Scalar(&'a Scalar),
207    Array(&'a dyn Array),
208    Mask(&'a Mask),
209    Builder(&'a mut dyn ArrayBuilder),
210    DType(&'a DType),
211}
212
213impl Debug for Input<'_> {
214    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215        let mut f = f.debug_struct("Input");
216        match self {
217            Input::Scalar(scalar) => f.field("Scalar", scalar),
218            Input::Array(array) => f.field("Array", array),
219            Input::Mask(mask) => f.field("Mask", mask),
220            Input::Builder(builder) => f.field("Builder", &builder.len()),
221            Input::DType(dtype) => f.field("DType", dtype),
222        };
223        f.finish()
224    }
225}
226
227impl<'a> From<&'a dyn Array> for Input<'a> {
228    fn from(value: &'a dyn Array) -> Self {
229        Input::Array(value)
230    }
231}
232
233impl<'a> From<&'a Scalar> for Input<'a> {
234    fn from(value: &'a Scalar) -> Self {
235        Input::Scalar(value)
236    }
237}
238
239impl<'a> From<&'a Mask> for Input<'a> {
240    fn from(value: &'a Mask) -> Self {
241        Input::Mask(value)
242    }
243}
244
245impl<'a> From<&'a DType> for Input<'a> {
246    fn from(value: &'a DType) -> Self {
247        Input::DType(value)
248    }
249}
250
251impl<'a> Input<'a> {
252    pub fn scalar(&self) -> Option<&'a Scalar> {
253        match self {
254            Input::Scalar(scalar) => Some(*scalar),
255            _ => None,
256        }
257    }
258
259    pub fn array(&self) -> Option<&'a dyn Array> {
260        match self {
261            Input::Array(array) => Some(*array),
262            _ => None,
263        }
264    }
265
266    pub fn mask(&self) -> Option<&'a Mask> {
267        match self {
268            Input::Mask(mask) => Some(*mask),
269            _ => None,
270        }
271    }
272
273    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
274        match self {
275            Input::Builder(builder) => Some(*builder),
276            _ => None,
277        }
278    }
279
280    pub fn dtype(&self) -> Option<&'a DType> {
281        match self {
282            Input::DType(dtype) => Some(*dtype),
283            _ => None,
284        }
285    }
286}
287
288/// Output from a compute function.
289#[derive(Debug)]
290pub enum Output {
291    Scalar(Scalar),
292    Array(ArrayRef),
293}
294
295#[allow(clippy::len_without_is_empty)]
296impl Output {
297    pub fn dtype(&self) -> &DType {
298        match self {
299            Output::Scalar(scalar) => scalar.dtype(),
300            Output::Array(array) => array.dtype(),
301        }
302    }
303
304    pub fn len(&self) -> usize {
305        match self {
306            Output::Scalar(_) => 1,
307            Output::Array(array) => array.len(),
308        }
309    }
310
311    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
312        match &self {
313            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
314            Output::Scalar(scalar) => Ok(scalar.clone()),
315        }
316    }
317
318    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
319        match &self {
320            Output::Array(array) => Ok(array.clone()),
321            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
322        }
323    }
324}
325
326impl From<ArrayRef> for Output {
327    fn from(value: ArrayRef) -> Self {
328        Output::Array(value)
329    }
330}
331
332impl From<Scalar> for Output {
333    fn from(value: Scalar) -> Self {
334        Output::Scalar(value)
335    }
336}
337
/// Function-specific options passed alongside the inputs of a compute function invocation.
///
/// Options expose themselves as [`Any`] so that the concrete options type can be recovered by
/// downcasting.
pub trait Options {
    /// Returns `self` as [`Any`].
    fn as_any(&self) -> &dyn Any;
}

/// Allows `()` to be passed as the options of an invocation.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
348
349/// Compute functions can ask arrays for compute kernels for a given invocation.
350///
351/// The kernel is invoked with the input arguments and options, and can return `None` if it is
352/// unable to compute the result for the given inputs due to missing implementation logic.
353/// For example, if kernel doesn't support the `LTE` operator.
354///
355/// If the kernel fails to compute a result, it should return a `Some` with the error.
356pub trait Kernel: 'static + Send + Sync + Debug {
357    /// Invokes the kernel with the given input arguments and options.
358    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
359}
360
/// Registers a kernel for a compute function.
///
/// The argument is forwarded unchanged to `inventory::submit!`, reached through
/// `$crate::aliases`. See each compute function for the kernel type it expects.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}