1use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11
12use arcref::ArcRef;
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::*;
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::*;
24pub use list_contains::*;
25pub use mask::*;
26pub use min_max::*;
27pub use nan_count::*;
28pub use numeric::*;
29use parking_lot::RwLock;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list_contains;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
/// A compute function: an identifier, the vtable that implements it, and the set of
/// kernels that have been registered against it.
pub struct ComputeFn {
    /// Identifier of the function; included in the error messages raised by `invoke`.
    id: ArcRef<str>,
    /// Implementation the function delegates to for invocation and result metadata.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Registered kernels, behind a lock so they can be added through `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
69
70impl ComputeFn {
71 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73 Self {
74 id,
75 vtable,
76 kernels: Default::default(),
77 }
78 }
79
80 pub fn id(&self) -> &ArcRef<str> {
82 &self.id
83 }
84
85 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87 self.kernels.write().push(kernel);
88 }
89
90 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
92 if self.is_elementwise() {
94 if !args
96 .inputs
97 .iter()
98 .filter_map(|input| input.array())
99 .map(|array| array.len())
100 .all_equal()
101 {
102 vortex_bail!(
103 "Compute function {} is elementwise but input arrays have different lengths",
104 self.id
105 );
106 }
107 }
108
109 let expected_dtype = self.vtable.return_dtype(args)?;
110 let expected_len = self.vtable.return_len(args)?;
111
112 let output = self.vtable.invoke(args, &self.kernels.read())?;
113
114 if output.dtype() != &expected_dtype {
115 vortex_bail!(
116 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
117 self.id,
118 output.dtype(),
119 &expected_dtype,
120 args.inputs
121 .iter()
122 .filter_map(|input| input.array())
123 .format_with(",", |array, f| f(&array.tree_display()))
124 );
125 }
126 if output.len() != expected_len {
127 vortex_bail!(
128 "Internal error: compute function {} returned a result of length {} but expected {}",
129 self.id,
130 output.len(),
131 expected_len
132 );
133 }
134
135 Ok(output)
136 }
137
138 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
140 self.vtable.return_dtype(args)
141 }
142
143 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
145 self.vtable.return_len(args)
146 }
147
148 pub fn is_elementwise(&self) -> bool {
150 self.vtable.is_elementwise()
152 }
153
154 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
156 self.kernels.read().to_vec()
157 }
158}
159
/// The implementation behind a [`ComputeFn`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Runs the function on `args`. `kernels` is the list of kernels currently
    /// registered on the owning [`ComputeFn`].
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// The dtype of the result for `args`. [`ComputeFn::invoke`] rejects an output
    /// whose dtype differs from this value.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// The length of the result for `args`. [`ComputeFn::invoke`] rejects an output
    /// whose length differs from this value.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise. When `true`, [`ComputeFn::invoke`]
    /// requires all array inputs to have the same length.
    fn is_elementwise(&self) -> bool;
}
190
/// Borrowed arguments for a compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs to the function.
    pub inputs: &'a [Input<'a>],
    /// Type-erased options; typed views are recovered by downcasting via `as_any`.
    pub options: &'a dyn Options,
}
197
/// Typed view over [`InvocationArgs`] holding exactly one array input and options
/// of concrete type `O`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The options, downcast to `O`.
    pub options: &'a O,
}
203
204impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
205 type Error = VortexError;
206
207 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
208 if value.inputs.len() != 1 {
209 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
210 }
211 let array = value.inputs[0]
212 .array()
213 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
214 let options =
215 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
216 vortex_err!("Expected options to be of type {}", type_name::<O>())
217 })?;
218 Ok(UnaryArgs { array, options })
219 }
220}
221
/// Typed view over [`InvocationArgs`] holding exactly two array inputs and options
/// of concrete type `O`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The options, downcast to `O`.
    pub options: &'a O,
}
228
229impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
230 type Error = VortexError;
231
232 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
233 if value.inputs.len() != 2 {
234 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
235 }
236 let lhs = value.inputs[0]
237 .array()
238 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
239 let rhs = value.inputs[1]
240 .array()
241 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
242 let options =
243 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
244 vortex_err!("Expected options to be of type {}", type_name::<O>())
245 })?;
246 Ok(BinaryArgs { lhs, rhs, options })
247 }
248}
249
/// A single borrowed input to a compute function.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed dtype.
    DType(&'a DType),
}
258
259impl Debug for Input<'_> {
260 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
261 let mut f = f.debug_struct("Input");
262 match self {
263 Input::Scalar(scalar) => f.field("Scalar", scalar),
264 Input::Array(array) => f.field("Array", array),
265 Input::Mask(mask) => f.field("Mask", mask),
266 Input::Builder(builder) => f.field("Builder", &builder.len()),
267 Input::DType(dtype) => f.field("DType", dtype),
268 };
269 f.finish()
270 }
271}
272
273impl<'a> From<&'a dyn Array> for Input<'a> {
274 fn from(value: &'a dyn Array) -> Self {
275 Input::Array(value)
276 }
277}
278
279impl<'a> From<&'a Scalar> for Input<'a> {
280 fn from(value: &'a Scalar) -> Self {
281 Input::Scalar(value)
282 }
283}
284
285impl<'a> From<&'a Mask> for Input<'a> {
286 fn from(value: &'a Mask) -> Self {
287 Input::Mask(value)
288 }
289}
290
291impl<'a> From<&'a DType> for Input<'a> {
292 fn from(value: &'a DType) -> Self {
293 Input::DType(value)
294 }
295}
296
297impl<'a> Input<'a> {
298 pub fn scalar(&self) -> Option<&'a Scalar> {
299 match self {
300 Input::Scalar(scalar) => Some(*scalar),
301 _ => None,
302 }
303 }
304
305 pub fn array(&self) -> Option<&'a dyn Array> {
306 match self {
307 Input::Array(array) => Some(*array),
308 _ => None,
309 }
310 }
311
312 pub fn mask(&self) -> Option<&'a Mask> {
313 match self {
314 Input::Mask(mask) => Some(*mask),
315 _ => None,
316 }
317 }
318
319 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
320 match self {
321 Input::Builder(builder) => Some(*builder),
322 _ => None,
323 }
324 }
325
326 pub fn dtype(&self) -> Option<&'a DType> {
327 match self {
328 Input::DType(dtype) => Some(*dtype),
329 _ => None,
330 }
331 }
332}
333
/// The owned result of a compute function: either a scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
340
341#[allow(clippy::len_without_is_empty)]
342impl Output {
343 pub fn dtype(&self) -> &DType {
344 match self {
345 Output::Scalar(scalar) => scalar.dtype(),
346 Output::Array(array) => array.dtype(),
347 }
348 }
349
350 pub fn len(&self) -> usize {
351 match self {
352 Output::Scalar(_) => 1,
353 Output::Array(array) => array.len(),
354 }
355 }
356
357 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
358 match self {
359 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
360 Output::Scalar(scalar) => Ok(scalar),
361 }
362 }
363
364 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
365 match self {
366 Output::Array(array) => Ok(array),
367 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
368 }
369 }
370}
371
372impl From<ArrayRef> for Output {
373 fn from(value: ArrayRef) -> Self {
374 Output::Array(value)
375 }
376}
377
378impl From<Scalar> for Output {
379 fn from(value: Scalar) -> Self {
380 Output::Scalar(value)
381 }
382}
383
/// Options passed to a compute function in type-erased form.
pub trait Options: 'static {
    /// Exposes the value as [`Any`] so callers can downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}
388
/// Allows the unit type `()` to be used as an [`Options`] value.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
394
/// A kernel that can be registered on a [`ComputeFn`] and is handed to its vtable
/// on invocation.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempts to compute the function for `args`.
    ///
    /// NOTE(review): presumably `Ok(None)` means this kernel does not handle the
    /// given arguments; confirm against the vtable implementations.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
408
/// Submits the expression `$T` to `inventory` through this crate's
/// `aliases::inventory` re-export.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}
416}