Skip to main content

vortex_array/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute kernels on top of Vortex Arrays.
5//!
6//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7//! and filter Vortex Arrays in their encoded forms.
8//!
//! Every array encoding may provide its own efficient implementations of these operators;
//! otherwise, we decode the array and perform the equivalent operator from Arrow.
11
12use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use is_constant::*;
19pub use is_sorted::*;
20use itertools::Itertools;
21pub use min_max::*;
22pub use nan_count::*;
23use parking_lot::RwLock;
24pub use sum::*;
25use vortex_error::VortexError;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29use vortex_mask::Mask;
30
31use crate::ArrayRef;
32use crate::DynArray;
33use crate::builders::ArrayBuilder;
34use crate::dtype::DType;
35use crate::scalar::Scalar;
36
37#[cfg(feature = "arbitrary")]
38mod arbitrary;
39#[cfg(feature = "_test-harness")]
40pub mod conformance;
41mod is_constant;
42mod is_sorted;
43mod min_max;
44mod nan_count;
45mod sum;
46
47/// An instance of a compute function holding the implementation vtable and a set of registered
48/// compute kernels.
49pub struct ComputeFn {
50    id: ArcRef<str>,
51    vtable: ArcRef<dyn ComputeFnVTable>,
52    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
53}
54
55impl ComputeFn {
56    /// Create a new compute function from the given [`ComputeFnVTable`].
57    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
58        Self {
59            id,
60            vtable,
61            kernels: Default::default(),
62        }
63    }
64
65    /// Returns the string identifier of the compute function.
66    pub fn id(&self) -> &ArcRef<str> {
67        &self.id
68    }
69
70    /// Register a kernel for the compute function.
71    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
72        self.kernels.write().push(kernel);
73    }
74
75    /// Invokes the compute function with the given arguments.
76    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
77        // Perform some pre-condition checks against the arguments and the function properties.
78        if self.is_elementwise() {
79            // For element-wise functions, all input arrays must be the same length.
80            if !args
81                .inputs
82                .iter()
83                .filter_map(|input| input.array())
84                .map(|array| array.len())
85                .all_equal()
86            {
87                vortex_bail!(
88                    "Compute function {} is elementwise but input arrays have different lengths",
89                    self.id
90                );
91            }
92        }
93
94        let expected_dtype = self.vtable.return_dtype(args)?;
95        let expected_len = self.vtable.return_len(args)?;
96
97        let output = self.vtable.invoke(args, &self.kernels.read())?;
98
99        if output.dtype() != &expected_dtype {
100            vortex_bail!(
101                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
102                self.id,
103                output.dtype(),
104                &expected_dtype,
105                args.inputs
106                    .iter()
107                    .filter_map(|input| input.array())
108                    .format_with(",", |array, f| f(&array.encoding_id()))
109            );
110        }
111        if output.len() != expected_len {
112            vortex_bail!(
113                "Internal error: compute function {} returned a result of length {} but expected {}",
114                self.id,
115                output.len(),
116                expected_len
117            );
118        }
119
120        Ok(output)
121    }
122
123    /// Compute the return type of the function given the input arguments.
124    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
125        self.vtable.return_dtype(args)
126    }
127
128    /// Compute the return length of the function given the input arguments.
129    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
130        self.vtable.return_len(args)
131    }
132
133    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
134    pub fn is_elementwise(&self) -> bool {
135        // TODO(ngates): should this just be a constant passed in the constructor?
136        self.vtable.is_elementwise()
137    }
138
139    /// Returns the compute function's kernels.
140    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
141        self.kernels.read().to_vec()
142    }
143}
144
145/// VTable for the implementation of a compute function.
146pub trait ComputeFnVTable: 'static + Send + Sync {
147    /// Invokes the compute function entry-point with the given input arguments and options.
148    ///
149    /// The entry-point logic can short-circuit compute using statistics, update result array
150    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
151    /// to successfully compute a result.
152    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
153    -> VortexResult<Output>;
154
155    /// Computes the return type of the function given the input arguments.
156    ///
157    /// All kernel implementations will be validated to return the [`DType`] as computed here.
158    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
159
160    /// Computes the return length of the function given the input arguments.
161    ///
162    /// All kernel implementations will be validated to return the len as computed here.
163    /// Scalars are considered to have length 1.
164    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
165
166    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
167    /// input and no information is shared between elements.
168    ///
169    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
170    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
171    ///
172    /// All input arrays to an elementwise function *must* have the same length.
173    fn is_elementwise(&self) -> bool;
174}
175
176/// Arguments to a compute function invocation.
177#[derive(Clone)]
178pub struct InvocationArgs<'a> {
179    pub inputs: &'a [Input<'a>],
180    pub options: &'a dyn Options,
181}
182
183/// For unary compute functions, it's useful to just have this short-cut.
184pub struct UnaryArgs<'a, O: Options> {
185    pub array: &'a dyn DynArray,
186    pub options: &'a O,
187}
188
189impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
190    type Error = VortexError;
191
192    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
193        if value.inputs.len() != 1 {
194            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
195        }
196        let array = value.inputs[0]
197            .array()
198            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
199        let options =
200            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
201                vortex_err!("Expected options to be of type {}", type_name::<O>())
202            })?;
203        Ok(UnaryArgs { array, options })
204    }
205}
206
207/// For binary compute functions, it's useful to just have this short-cut.
208pub struct BinaryArgs<'a, O: Options> {
209    pub lhs: &'a dyn DynArray,
210    pub rhs: &'a dyn DynArray,
211    pub options: &'a O,
212}
213
214impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
215    type Error = VortexError;
216
217    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
218        if value.inputs.len() != 2 {
219            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
220        }
221        let lhs = value.inputs[0]
222            .array()
223            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
224        let rhs = value.inputs[1]
225            .array()
226            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
227        let options =
228            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
229                vortex_err!("Expected options to be of type {}", type_name::<O>())
230            })?;
231        Ok(BinaryArgs { lhs, rhs, options })
232    }
233}
234
235/// Input to a compute function.
236pub enum Input<'a> {
237    Scalar(&'a Scalar),
238    Array(&'a dyn DynArray),
239    Mask(&'a Mask),
240    Builder(&'a mut dyn ArrayBuilder),
241    DType(&'a DType),
242}
243
244impl Debug for Input<'_> {
245    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
246        let mut f = f.debug_struct("Input");
247        match self {
248            Input::Scalar(scalar) => f.field("Scalar", scalar),
249            Input::Array(array) => f.field("Array", array),
250            Input::Mask(mask) => f.field("Mask", mask),
251            Input::Builder(builder) => f.field("Builder", &builder.len()),
252            Input::DType(dtype) => f.field("DType", dtype),
253        };
254        f.finish()
255    }
256}
257
258impl<'a> From<&'a dyn DynArray> for Input<'a> {
259    fn from(value: &'a dyn DynArray) -> Self {
260        Input::Array(value)
261    }
262}
263
264impl<'a> From<&'a ArrayRef> for Input<'a> {
265    fn from(value: &'a ArrayRef) -> Self {
266        Input::Array(value.as_ref())
267    }
268}
269
270impl<'a> From<&'a Scalar> for Input<'a> {
271    fn from(value: &'a Scalar) -> Self {
272        Input::Scalar(value)
273    }
274}
275
276impl<'a> From<&'a Mask> for Input<'a> {
277    fn from(value: &'a Mask) -> Self {
278        Input::Mask(value)
279    }
280}
281
282impl<'a> From<&'a DType> for Input<'a> {
283    fn from(value: &'a DType) -> Self {
284        Input::DType(value)
285    }
286}
287
288impl<'a> Input<'a> {
289    pub fn scalar(&self) -> Option<&'a Scalar> {
290        if let Input::Scalar(scalar) = self {
291            Some(*scalar)
292        } else {
293            None
294        }
295    }
296
297    pub fn array(&self) -> Option<&'a dyn DynArray> {
298        if let Input::Array(array) = self {
299            Some(*array)
300        } else {
301            None
302        }
303    }
304
305    pub fn mask(&self) -> Option<&'a Mask> {
306        if let Input::Mask(mask) = self {
307            Some(*mask)
308        } else {
309            None
310        }
311    }
312
313    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
314        if let Input::Builder(builder) = self {
315            Some(*builder)
316        } else {
317            None
318        }
319    }
320
321    pub fn dtype(&self) -> Option<&'a DType> {
322        if let Input::DType(dtype) = self {
323            Some(*dtype)
324        } else {
325            None
326        }
327    }
328}
329
330/// Output from a compute function.
331#[derive(Debug)]
332pub enum Output {
333    Scalar(Scalar),
334    Array(ArrayRef),
335}
336
337#[expect(
338    clippy::len_without_is_empty,
339    reason = "Output is always non-empty (scalar has len 1)"
340)]
341impl Output {
342    pub fn dtype(&self) -> &DType {
343        match self {
344            Output::Scalar(scalar) => scalar.dtype(),
345            Output::Array(array) => array.dtype(),
346        }
347    }
348
349    pub fn len(&self) -> usize {
350        match self {
351            Output::Scalar(_) => 1,
352            Output::Array(array) => array.len(),
353        }
354    }
355
356    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
357        match self {
358            Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
359            Output::Scalar(scalar) => Ok(scalar),
360        }
361    }
362
363    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
364        match self {
365            Output::Array(array) => Ok(array),
366            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
367        }
368    }
369}
370
371impl From<ArrayRef> for Output {
372    fn from(value: ArrayRef) -> Self {
373        Output::Array(value)
374    }
375}
376
377impl From<Scalar> for Output {
378    fn from(value: Scalar) -> Self {
379        Output::Scalar(value)
380    }
381}
382
/// Type-erased options passed to a compute function invocation.
///
/// Implementors expose themselves as [`Any`], so callers can downcast the options back to
/// their concrete type.
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any`, ready to be downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}

/// Lets `()` serve as an empty set of options.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
393
394/// Compute functions can ask arrays for compute kernels for a given invocation.
395///
396/// The kernel is invoked with the input arguments and options, and can return `None` if it is
397/// unable to compute the result for the given inputs due to missing implementation logic.
398/// For example, if kernel doesn't support the `LTE` operator. By returning `None`, the kernel
399/// is indicating that it cannot compute the result for the given inputs, and another kernel should
400/// be tried. *Not* that the given inputs are invalid for the compute function.
401///
402/// If the kernel fails to compute a result, it should return a `Some` with the error.
403pub trait Kernel: 'static + Send + Sync + Debug {
404    /// Invokes the kernel with the given input arguments and options.
405    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
406}
407
/// Registers a kernel for a compute function by submitting it to the crate's `inventory`
/// registry.
///
/// Each compute function documents which kernel type it expects to be registered.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}
415}