vortex_array/compute/
mod.rs

1//! Compute kernels on top of Vortex Arrays.
2//!
3//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
4//! and filter Vortex Arrays in their encoded forms.
5//!
//! Every array encoding can provide its own efficient implementation of these operators;
//! otherwise we decode the array and perform the equivalent operation using Arrow.
8
9use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11
12use arcref::ArcRef;
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::*;
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::*;
24pub use list::*;
25pub use mask::*;
26pub use min_max::*;
27pub use nan_count::*;
28pub use numeric::*;
29use parking_lot::RwLock;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
/// An instance of a compute function holding the implementation vtable and a set of registered
/// compute kernels.
pub struct ComputeFn {
    /// String identifier of the function; it is included in the validation error messages.
    id: ArcRef<str>,
    /// Implementation entry-point, which also reports the expected return dtype and length.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels registered so far. Kept behind a lock because registration only takes `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
69
70impl ComputeFn {
71    /// Create a new compute function from the given [`ComputeFnVTable`].
72    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73        Self {
74            id,
75            vtable,
76            kernels: Default::default(),
77        }
78    }
79
80    /// Returns the string identifier of the compute function.
81    pub fn id(&self) -> &ArcRef<str> {
82        &self.id
83    }
84
85    /// Register a kernel for the compute function.
86    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87        self.kernels.write().push(kernel);
88    }
89
90    /// Invokes the compute function with the given arguments.
91    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
92        // Perform some pre-condition checks against the arguments and the function properties.
93        if self.is_elementwise() {
94            // For element-wise functions, all input arrays must be the same length.
95            if !args
96                .inputs
97                .iter()
98                .filter_map(|input| input.array())
99                .map(|array| array.len())
100                .all_equal()
101            {
102                vortex_bail!(
103                    "Compute function {} is elementwise but input arrays have different lengths",
104                    self.id
105                );
106            }
107        }
108
109        let expected_dtype = self.vtable.return_dtype(args)?;
110        let expected_len = self.vtable.return_len(args)?;
111
112        let output = self.vtable.invoke(args, &self.kernels.read())?;
113
114        if output.dtype() != &expected_dtype {
115            vortex_bail!(
116                "Internal error: compute function {} returned a result of type {} but expected {}",
117                self.id,
118                output.dtype(),
119                &expected_dtype
120            );
121        }
122        if output.len() != expected_len {
123            vortex_bail!(
124                "Internal error: compute function {} returned a result of length {} but expected {}",
125                self.id,
126                output.len(),
127                expected_len
128            );
129        }
130
131        Ok(output)
132    }
133
134    /// Compute the return type of the function given the input arguments.
135    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
136        self.vtable.return_dtype(args)
137    }
138
139    /// Compute the return length of the function given the input arguments.
140    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
141        self.vtable.return_len(args)
142    }
143
144    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
145    pub fn is_elementwise(&self) -> bool {
146        // TODO(ngates): should this just be a constant passed in the constructor?
147        self.vtable.is_elementwise()
148    }
149
150    /// Returns the compute function's kernels.
151    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
152        self.kernels.read().to_vec()
153    }
154}
155
/// VTable for the implementation of a compute function.
///
/// A [`ComputeFn`] delegates to this trait for the actual computation and for the expected dtype
/// and length that the result is validated against.
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the compute function entry-point with the given input arguments and options.
    ///
    /// `kernels` is the list of kernels registered on the owning [`ComputeFn`] at the time of the
    /// call.
    ///
    /// The entry-point logic can short-circuit compute using statistics, update result array
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
    /// to successfully compute a result.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the return type of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the return length of the function given the input arguments.
    ///
    /// All kernel implementations will be validated to return the len as computed here.
    /// Scalars are considered to have length 1.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
    /// input and no information is shared between elements.
    ///
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
    ///
    /// All input arrays to an elementwise function *must* have the same length.
    fn is_elementwise(&self) -> bool;
}
186
/// Arguments to a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs; for example [`BinaryArgs`] reads index 0 as `lhs` and index 1 as `rhs`.
    pub inputs: &'a [Input<'a>],
    /// Type-erased options; the concrete type is recovered by downcasting [`Options::as_any`].
    pub options: &'a dyn Options,
}
193
/// For unary compute functions, it's useful to just have this short-cut.
///
/// Obtained from an [`InvocationArgs`] through the [`TryFrom`] implementation, which requires
/// exactly one input that is an array and options of the concrete type `O`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The invocation options, downcast to their concrete type.
    pub options: &'a O,
}
199
200impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
201    type Error = VortexError;
202
203    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
204        if value.inputs.len() != 1 {
205            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
206        }
207        let array = value.inputs[0]
208            .array()
209            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
210        let options =
211            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
212                vortex_err!("Expected options to be of type {}", type_name::<O>())
213            })?;
214        Ok(UnaryArgs { array, options })
215    }
216}
217
/// For binary compute functions, it's useful to just have this short-cut.
///
/// Obtained from an [`InvocationArgs`] through the [`TryFrom`] implementation, which requires
/// exactly two inputs that are both arrays and options of the concrete type `O`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options, downcast to their concrete type.
    pub options: &'a O,
}
224
225impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
226    type Error = VortexError;
227
228    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
229        if value.inputs.len() != 2 {
230            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
231        }
232        let lhs = value.inputs[0]
233            .array()
234            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
235        let rhs = value.inputs[1]
236            .array()
237            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
238        let options =
239            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
240                vortex_err!("Expected options to be of type {}", type_name::<O>())
241            })?;
242        Ok(BinaryArgs { lhs, rhs, options })
243    }
244}
245
/// Input to a compute function.
///
/// Every variant borrows its payload for `'a`; `Builder` is the only variant that borrows
/// mutably.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    /// NOTE(review): presumably a sink the function appends its result to — confirm with callers.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
254
255impl Debug for Input<'_> {
256    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
257        let mut f = f.debug_struct("Input");
258        match self {
259            Input::Scalar(scalar) => f.field("Scalar", scalar),
260            Input::Array(array) => f.field("Array", array),
261            Input::Mask(mask) => f.field("Mask", mask),
262            Input::Builder(builder) => f.field("Builder", &builder.len()),
263            Input::DType(dtype) => f.field("DType", dtype),
264        };
265        f.finish()
266    }
267}
268
269impl<'a> From<&'a dyn Array> for Input<'a> {
270    fn from(value: &'a dyn Array) -> Self {
271        Input::Array(value)
272    }
273}
274
275impl<'a> From<&'a Scalar> for Input<'a> {
276    fn from(value: &'a Scalar) -> Self {
277        Input::Scalar(value)
278    }
279}
280
281impl<'a> From<&'a Mask> for Input<'a> {
282    fn from(value: &'a Mask) -> Self {
283        Input::Mask(value)
284    }
285}
286
287impl<'a> From<&'a DType> for Input<'a> {
288    fn from(value: &'a DType) -> Self {
289        Input::DType(value)
290    }
291}
292
293impl<'a> Input<'a> {
294    pub fn scalar(&self) -> Option<&'a Scalar> {
295        match self {
296            Input::Scalar(scalar) => Some(*scalar),
297            _ => None,
298        }
299    }
300
301    pub fn array(&self) -> Option<&'a dyn Array> {
302        match self {
303            Input::Array(array) => Some(*array),
304            _ => None,
305        }
306    }
307
308    pub fn mask(&self) -> Option<&'a Mask> {
309        match self {
310            Input::Mask(mask) => Some(*mask),
311            _ => None,
312        }
313    }
314
315    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
316        match self {
317            Input::Builder(builder) => Some(*builder),
318            _ => None,
319        }
320    }
321
322    pub fn dtype(&self) -> Option<&'a DType> {
323        match self {
324            Input::DType(dtype) => Some(*dtype),
325            _ => None,
326        }
327    }
328}
329
/// Output from a compute function.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result; reported as having length 1.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
336
337#[allow(clippy::len_without_is_empty)]
338impl Output {
339    pub fn dtype(&self) -> &DType {
340        match self {
341            Output::Scalar(scalar) => scalar.dtype(),
342            Output::Array(array) => array.dtype(),
343        }
344    }
345
346    pub fn len(&self) -> usize {
347        match self {
348            Output::Scalar(_) => 1,
349            Output::Array(array) => array.len(),
350        }
351    }
352
353    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
354        match self {
355            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
356            Output::Scalar(scalar) => Ok(scalar),
357        }
358    }
359
360    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
361        match self {
362            Output::Array(array) => Ok(array),
363            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
364        }
365    }
366}
367
368impl From<ArrayRef> for Output {
369    fn from(value: ArrayRef) -> Self {
370        Output::Array(value)
371    }
372}
373
374impl From<Scalar> for Output {
375    fn from(value: Scalar) -> Self {
376        Output::Scalar(value)
377    }
378}
379
/// Options for a compute function invocation.
///
/// Options travel through [`InvocationArgs`] as a type-erased `&dyn Options`; consumers recover
/// the concrete type by downcasting the value returned from [`Options::as_any`].
pub trait Options: 'static {
    /// Returns this value as `&dyn Any` so that it can be downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}
384
385impl Options for () {
386    fn as_any(&self) -> &dyn Any {
387        self
388    }
389}
390
/// Compute functions can ask arrays for compute kernels for a given invocation.
///
/// The kernel is invoked with the input arguments and options, and can return `Ok(None)` if it is
/// unable to compute the result for the given inputs due to missing implementation logic, for
/// example if the kernel doesn't support the `LTE` operator. Returning `Ok(None)` indicates that
/// this kernel cannot compute the result for the given inputs and that another kernel should be
/// tried; it does *not* mean that the given inputs are invalid for the compute function.
///
/// If the kernel fails to compute a result, it should return an `Err`.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Invokes the kernel with the given input arguments and options.
    ///
    /// Returns `Ok(Some(output))` on success, `Ok(None)` if this kernel does not handle the
    /// given inputs, and `Err` if the computation failed.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
404
/// Register a kernel for a compute function.
/// See each compute function for the correct type of kernel to register.
///
/// Expands to a single `inventory::submit!` call (reached through `$crate::aliases::inventory`)
/// on the given expression.
#[macro_export]
macro_rules! register_kernel {
    // `$T` is the expression that gets submitted.
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}