1use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46pub mod arrays;
47mod between;
48mod boolean;
49mod cast;
50mod compare;
51#[cfg(feature = "test-harness")]
52pub mod conformance;
53mod fill_null;
54mod filter;
55mod invert;
56mod is_constant;
57mod is_sorted;
58mod like;
59mod list_contains;
60mod mask;
61mod min_max;
62mod nan_count;
63mod numeric;
64mod sum;
65mod take;
66mod zip;
67
68pub struct ComputeFn {
71 id: ArcRef<str>,
72 vtable: ArcRef<dyn ComputeFnVTable>,
73 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
74}
75
76pub fn warm_up_vtables() {
80 crate::arrow::warm_up_vtable();
81 #[allow(unused_qualifications)]
82 between::warm_up_vtable();
83 boolean::warm_up_vtable();
84 cast::warm_up_vtable();
85 compare::warm_up_vtable();
86 fill_null::warm_up_vtable();
87 filter::warm_up_vtable();
88 invert::warm_up_vtable();
89 is_constant::warm_up_vtable();
90 is_sorted::warm_up_vtable();
91 like::warm_up_vtable();
92 list_contains::warm_up_vtable();
93 mask::warm_up_vtable();
94 min_max::warm_up_vtable();
95 nan_count::warm_up_vtable();
96 numeric::warm_up_vtable();
97 sum::warm_up_vtable();
98 take::warm_up_vtable();
99 zip::warm_up_vtable();
100}
101
102impl ComputeFn {
103 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
105 Self {
106 id,
107 vtable,
108 kernels: Default::default(),
109 }
110 }
111
112 pub fn id(&self) -> &ArcRef<str> {
114 &self.id
115 }
116
117 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
119 self.kernels.write().push(kernel);
120 }
121
122 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
124 if self.is_elementwise() {
126 if !args
128 .inputs
129 .iter()
130 .filter_map(|input| input.array())
131 .map(|array| array.len())
132 .all_equal()
133 {
134 vortex_bail!(
135 "Compute function {} is elementwise but input arrays have different lengths",
136 self.id
137 );
138 }
139 }
140
141 let expected_dtype = self.vtable.return_dtype(args)?;
142 let expected_len = self.vtable.return_len(args)?;
143
144 let output = self.vtable.invoke(args, &self.kernels.read())?;
145
146 if output.dtype() != &expected_dtype {
147 vortex_bail!(
148 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
149 self.id,
150 output.dtype(),
151 &expected_dtype,
152 args.inputs
153 .iter()
154 .filter_map(|input| input.array())
155 .format_with(",", |array, f| f(&array.display_tree()))
156 );
157 }
158 if output.len() != expected_len {
159 vortex_bail!(
160 "Internal error: compute function {} returned a result of length {} but expected {}",
161 self.id,
162 output.len(),
163 expected_len
164 );
165 }
166
167 Ok(output)
168 }
169
170 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
172 self.vtable.return_dtype(args)
173 }
174
175 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
177 self.vtable.return_len(args)
178 }
179
180 pub fn is_elementwise(&self) -> bool {
182 self.vtable.is_elementwise()
184 }
185
186 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
188 self.kernels.read().to_vec()
189 }
190}
191
192pub trait ComputeFnVTable: 'static + Send + Sync {
194 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
200 -> VortexResult<Output>;
201
202 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
206
207 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
212
213 fn is_elementwise(&self) -> bool;
221}
222
223#[derive(Clone)]
225pub struct InvocationArgs<'a> {
226 pub inputs: &'a [Input<'a>],
227 pub options: &'a dyn Options,
228}
229
230pub struct UnaryArgs<'a, O: Options> {
232 pub array: &'a dyn Array,
233 pub options: &'a O,
234}
235
236impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
237 type Error = VortexError;
238
239 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
240 if value.inputs.len() != 1 {
241 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
242 }
243 let array = value.inputs[0]
244 .array()
245 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
246 let options =
247 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
248 vortex_err!("Expected options to be of type {}", type_name::<O>())
249 })?;
250 Ok(UnaryArgs { array, options })
251 }
252}
253
254pub struct BinaryArgs<'a, O: Options> {
256 pub lhs: &'a dyn Array,
257 pub rhs: &'a dyn Array,
258 pub options: &'a O,
259}
260
261impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
262 type Error = VortexError;
263
264 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
265 if value.inputs.len() != 2 {
266 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
267 }
268 let lhs = value.inputs[0]
269 .array()
270 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
271 let rhs = value.inputs[1]
272 .array()
273 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
274 let options =
275 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
276 vortex_err!("Expected options to be of type {}", type_name::<O>())
277 })?;
278 Ok(BinaryArgs { lhs, rhs, options })
279 }
280}
281
282pub enum Input<'a> {
284 Scalar(&'a Scalar),
285 Array(&'a dyn Array),
286 Mask(&'a Mask),
287 Builder(&'a mut dyn ArrayBuilder),
288 DType(&'a DType),
289}
290
291impl Debug for Input<'_> {
292 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293 let mut f = f.debug_struct("Input");
294 match self {
295 Input::Scalar(scalar) => f.field("Scalar", scalar),
296 Input::Array(array) => f.field("Array", array),
297 Input::Mask(mask) => f.field("Mask", mask),
298 Input::Builder(builder) => f.field("Builder", &builder.len()),
299 Input::DType(dtype) => f.field("DType", dtype),
300 };
301 f.finish()
302 }
303}
304
305impl<'a> From<&'a dyn Array> for Input<'a> {
306 fn from(value: &'a dyn Array) -> Self {
307 Input::Array(value)
308 }
309}
310
311impl<'a> From<&'a Scalar> for Input<'a> {
312 fn from(value: &'a Scalar) -> Self {
313 Input::Scalar(value)
314 }
315}
316
317impl<'a> From<&'a Mask> for Input<'a> {
318 fn from(value: &'a Mask) -> Self {
319 Input::Mask(value)
320 }
321}
322
323impl<'a> From<&'a DType> for Input<'a> {
324 fn from(value: &'a DType) -> Self {
325 Input::DType(value)
326 }
327}
328
329impl<'a> Input<'a> {
330 pub fn scalar(&self) -> Option<&'a Scalar> {
331 if let Input::Scalar(scalar) = self {
332 Some(*scalar)
333 } else {
334 None
335 }
336 }
337
338 pub fn array(&self) -> Option<&'a dyn Array> {
339 if let Input::Array(array) = self {
340 Some(*array)
341 } else {
342 None
343 }
344 }
345
346 pub fn mask(&self) -> Option<&'a Mask> {
347 if let Input::Mask(mask) = self {
348 Some(*mask)
349 } else {
350 None
351 }
352 }
353
354 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
355 if let Input::Builder(builder) = self {
356 Some(*builder)
357 } else {
358 None
359 }
360 }
361
362 pub fn dtype(&self) -> Option<&'a DType> {
363 if let Input::DType(dtype) = self {
364 Some(*dtype)
365 } else {
366 None
367 }
368 }
369}
370
371#[derive(Debug)]
373pub enum Output {
374 Scalar(Scalar),
375 Array(ArrayRef),
376}
377
378#[allow(clippy::len_without_is_empty)]
379impl Output {
380 pub fn dtype(&self) -> &DType {
381 match self {
382 Output::Scalar(scalar) => scalar.dtype(),
383 Output::Array(array) => array.dtype(),
384 }
385 }
386
387 pub fn len(&self) -> usize {
388 match self {
389 Output::Scalar(_) => 1,
390 Output::Array(array) => array.len(),
391 }
392 }
393
394 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
395 match self {
396 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
397 Output::Scalar(scalar) => Ok(scalar),
398 }
399 }
400
401 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
402 match self {
403 Output::Array(array) => Ok(array),
404 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
405 }
406 }
407}
408
409impl From<ArrayRef> for Output {
410 fn from(value: ArrayRef) -> Self {
411 Output::Array(value)
412 }
413}
414
415impl From<Scalar> for Output {
416 fn from(value: Scalar) -> Self {
417 Output::Scalar(value)
418 }
419}
420
/// Type-erased, function-specific options carried by [`InvocationArgs`].
pub trait Options: 'static {
    /// Exposes `self` as [`Any`] so callers can downcast to the concrete options type
    /// (as the `TryFrom<&InvocationArgs>` impls for `UnaryArgs`/`BinaryArgs` do).
    fn as_any(&self) -> &dyn Any;
}
425
426impl Options for () {
427 fn as_any(&self) -> &dyn Any {
428 self
429 }
430}
431
432pub trait Kernel: 'static + Send + Sync + Debug {
442 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
444}
445
/// Submits the given expression to the `inventory` crate re-exported at
/// `$crate::aliases::inventory`; expands to a single `inventory::submit!` invocation.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}