1use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11use std::sync::RwLock;
12
13use arcref::ArcRef;
14pub use between::*;
15pub use boolean::*;
16pub use cast::*;
17pub use compare::*;
18pub use fill_null::*;
19pub use filter::*;
20pub use invert::*;
21pub use is_constant::*;
22pub use is_sorted::*;
23use itertools::Itertools;
24pub use like::*;
25pub use list::*;
26pub use mask::*;
27pub use min_max::*;
28pub use nan_count::*;
29pub use numeric::*;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
/// A named compute function.
///
/// Bundles an identifier, a vtable that defines the function's semantics (return dtype/length,
/// elementwise-ness, invocation), and a list of kernels that can be extended at runtime and is
/// handed to the vtable on every invocation.
pub struct ComputeFn {
    id: ArcRef<str>,
    vtable: ArcRef<dyn ComputeFnVTable>,
    // Behind a lock so kernels can be registered through `&self` (see `register_kernel`).
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
69
70impl ComputeFn {
71    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73        Self {
74            id,
75            vtable,
76            kernels: Default::default(),
77        }
78    }
79
80    pub fn id(&self) -> &ArcRef<str> {
82        &self.id
83    }
84
85    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87        self.kernels
88            .write()
89            .vortex_expect("poisoned lock")
90            .push(kernel);
91    }
92
93    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
95        if self.is_elementwise() {
97            if !args
99                .inputs
100                .iter()
101                .filter_map(|input| input.array())
102                .map(|array| array.len())
103                .all_equal()
104            {
105                vortex_bail!(
106                    "Compute function {} is elementwise but input arrays have different lengths",
107                    self.id
108                );
109            }
110        }
111
112        let expected_dtype = self.vtable.return_dtype(args)?;
113        let expected_len = self.vtable.return_len(args)?;
114
115        let output = self
116            .vtable
117            .invoke(args, &self.kernels.read().vortex_expect("poisoned lock"))?;
118
119        if output.dtype() != &expected_dtype {
120            vortex_bail!(
121                "Internal error: compute function {} returned a result of type {} but expected {}",
122                self.id,
123                output.dtype(),
124                &expected_dtype
125            );
126        }
127        if output.len() != expected_len {
128            vortex_bail!(
129                "Internal error: compute function {} returned a result of length {} but expected {}",
130                self.id,
131                output.len(),
132                expected_len
133            );
134        }
135
136        Ok(output)
137    }
138
139    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
141        self.vtable.return_dtype(args)
142    }
143
144    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
146        self.vtable.return_len(args)
147    }
148
149    pub fn is_elementwise(&self) -> bool {
151        self.vtable.is_elementwise()
153    }
154
155    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
157        self.kernels.read().vortex_expect("poisoned lock").to_vec()
158    }
159}
160
/// The behavior behind a [`ComputeFn`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Computes the result for `args`, given the kernels currently registered on the owning
    /// [`ComputeFn`]. How kernels are chosen or fallen back from is up to the implementation.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Returns the dtype the result for `args` must have; [`ComputeFn::invoke`] rejects any
    /// output whose dtype differs.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Returns the length the result for `args` must have; [`ComputeFn::invoke`] rejects any
    /// output whose length differs (a scalar output counts as length 1).
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise. When `true`, [`ComputeFn::invoke`] requires all
    /// array inputs to have the same length.
    fn is_elementwise(&self) -> bool;
}
191
/// The arguments to a compute-function invocation: positional inputs plus type-erased options.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs; the expected count and kind of each is defined by the function.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options, recovered by downcasting through [`Options::as_any`].
    pub options: &'a dyn Options,
}
198
/// Typed view of [`InvocationArgs`] for functions taking exactly one array input and options of
/// type `O`. Built with `UnaryArgs::try_from(&args)`.
pub struct UnaryArgs<'a, O: Options> {
    /// The sole array input (input 0).
    pub array: &'a dyn Array,
    /// The invocation options, downcast to `O`.
    pub options: &'a O,
}
204
205impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
206    type Error = VortexError;
207
208    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
209        if value.inputs.len() != 1 {
210            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
211        }
212        let array = value.inputs[0]
213            .array()
214            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
215        let options =
216            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
217                vortex_err!("Expected options to be of type {}", type_name::<O>())
218            })?;
219        Ok(UnaryArgs { array, options })
220    }
221}
222
/// Typed view of [`InvocationArgs`] for functions taking exactly two array inputs and options of
/// type `O`. Built with `BinaryArgs::try_from(&args)`.
pub struct BinaryArgs<'a, O: Options> {
    /// The left-hand array (input 0).
    pub lhs: &'a dyn Array,
    /// The right-hand array (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options, downcast to `O`.
    pub options: &'a O,
}
229
230impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
231    type Error = VortexError;
232
233    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
234        if value.inputs.len() != 2 {
235            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
236        }
237        let lhs = value.inputs[0]
238            .array()
239            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
240        let rhs = value.inputs[1]
241            .array()
242            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
243        let options =
244            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
245                vortex_err!("Expected options to be of type {}", type_name::<O>())
246            })?;
247        Ok(BinaryArgs { lhs, rhs, options })
248    }
249}
250
/// A single positional input to a compute function. Every variant borrows its payload for `'a`.
pub enum Input<'a> {
    Scalar(&'a Scalar),
    Array(&'a dyn Array),
    Mask(&'a Mask),
    // NOTE(review): the only mutable variant; presumably an output sink that a function appends
    // into — confirm against the functions that accept it.
    Builder(&'a mut dyn ArrayBuilder),
    DType(&'a DType),
}
259
260impl Debug for Input<'_> {
261    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
262        let mut f = f.debug_struct("Input");
263        match self {
264            Input::Scalar(scalar) => f.field("Scalar", scalar),
265            Input::Array(array) => f.field("Array", array),
266            Input::Mask(mask) => f.field("Mask", mask),
267            Input::Builder(builder) => f.field("Builder", &builder.len()),
268            Input::DType(dtype) => f.field("DType", dtype),
269        };
270        f.finish()
271    }
272}
273
274impl<'a> From<&'a dyn Array> for Input<'a> {
275    fn from(value: &'a dyn Array) -> Self {
276        Input::Array(value)
277    }
278}
279
280impl<'a> From<&'a Scalar> for Input<'a> {
281    fn from(value: &'a Scalar) -> Self {
282        Input::Scalar(value)
283    }
284}
285
286impl<'a> From<&'a Mask> for Input<'a> {
287    fn from(value: &'a Mask) -> Self {
288        Input::Mask(value)
289    }
290}
291
292impl<'a> From<&'a DType> for Input<'a> {
293    fn from(value: &'a DType) -> Self {
294        Input::DType(value)
295    }
296}
297
298impl<'a> Input<'a> {
299    pub fn scalar(&self) -> Option<&'a Scalar> {
300        match self {
301            Input::Scalar(scalar) => Some(*scalar),
302            _ => None,
303        }
304    }
305
306    pub fn array(&self) -> Option<&'a dyn Array> {
307        match self {
308            Input::Array(array) => Some(*array),
309            _ => None,
310        }
311    }
312
313    pub fn mask(&self) -> Option<&'a Mask> {
314        match self {
315            Input::Mask(mask) => Some(*mask),
316            _ => None,
317        }
318    }
319
320    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
321        match self {
322            Input::Builder(builder) => Some(*builder),
323            _ => None,
324        }
325    }
326
327    pub fn dtype(&self) -> Option<&'a DType> {
328        match self {
329            Input::DType(dtype) => Some(*dtype),
330            _ => None,
331        }
332    }
333}
334
/// The result of invoking a compute function: either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    Scalar(Scalar),
    Array(ArrayRef),
}
341
342#[allow(clippy::len_without_is_empty)]
343impl Output {
344    pub fn dtype(&self) -> &DType {
345        match self {
346            Output::Scalar(scalar) => scalar.dtype(),
347            Output::Array(array) => array.dtype(),
348        }
349    }
350
351    pub fn len(&self) -> usize {
352        match self {
353            Output::Scalar(_) => 1,
354            Output::Array(array) => array.len(),
355        }
356    }
357
358    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
359        match self {
360            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
361            Output::Scalar(scalar) => Ok(scalar),
362        }
363    }
364
365    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
366        match self {
367            Output::Array(array) => Ok(array),
368            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
369        }
370    }
371}
372
373impl From<ArrayRef> for Output {
374    fn from(value: ArrayRef) -> Self {
375        Output::Array(value)
376    }
377}
378
379impl From<Scalar> for Output {
380    fn from(value: Scalar) -> Self {
381        Output::Scalar(value)
382    }
383}
384
/// Options for a compute function.
///
/// Options travel type-erased inside [`InvocationArgs`] and are recovered by downcasting the
/// value returned from [`Options::as_any`] to the concrete type.
pub trait Options: 'static {
    /// Returns a `&dyn Any` view of the options, used for downcasting to the concrete type.
    fn as_any(&self) -> &dyn Any;
}
389
390impl Options for () {
391    fn as_any(&self) -> &dyn Any {
392        self
393    }
394}
395
/// A concrete implementation of a compute function. Kernels are registered on a [`ComputeFn`]
/// and passed to its [`ComputeFnVTable::invoke`].
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempts to compute the result for `args`.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel does not handle these arguments",
    /// letting the vtable try another kernel or a fallback — confirm against vtable impls.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
409
/// Submits `$T` to an `inventory` collection, using the crate's re-export at
/// `$crate::aliases::inventory`.
///
/// NOTE(review): the required type of `$T`, and where the matching collection is iterated, are
/// not visible here. Presumably `$T` is a kernel registration entry; confirm against the
/// `inventory::collect!` site.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}
417}