1use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46mod between;
47mod boolean;
48mod cast;
49mod compare;
50#[cfg(feature = "test-harness")]
51pub mod conformance;
52mod fill_null;
53mod filter;
54mod invert;
55mod is_constant;
56mod is_sorted;
57mod like;
58mod list_contains;
59mod mask;
60mod min_max;
61mod nan_count;
62mod numeric;
63mod sum;
64mod take;
65mod zip;
66
/// A named compute function whose behaviour is defined by a [`ComputeFnVTable`] and which can
/// be extended at runtime with additional [`Kernel`]s.
pub struct ComputeFn {
    /// Identifier of the function; it is included in the error messages produced by
    /// [`ComputeFn::invoke`].
    id: ArcRef<str>,
    /// The function's definition: output dtype/length computation and the invocation logic,
    /// which receives the registered kernels.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added via [`ComputeFn::register_kernel`]. Guarded by a lock so that kernels can
    /// be registered through a shared reference.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
74
/// Calls `warm_up_vtable` on the Arrow conversion module and then on each compute module listed
/// below, in a fixed order.
///
/// NOTE(review): the name suggests this forces lazily-initialised vtables / kernel registries up
/// front so that the cost is not paid on the first real invocation — confirm against the
/// individual `warm_up_vtable` implementations.
pub fn warm_up_vtables() {
    crate::arrow::warm_up_vtable();
    // Lint suppression for the qualified call below. TODO(review): document why only `between`
    // needs it (the original rationale is not visible here).
    #[allow(unused_qualifications)]
    between::warm_up_vtable();
    boolean::warm_up_vtable();
    cast::warm_up_vtable();
    compare::warm_up_vtable();
    fill_null::warm_up_vtable();
    filter::warm_up_vtable();
    invert::warm_up_vtable();
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    like::warm_up_vtable();
    list_contains::warm_up_vtable();
    mask::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    numeric::warm_up_vtable();
    sum::warm_up_vtable();
    take::warm_up_vtable();
    zip::warm_up_vtable();
}
100
101impl ComputeFn {
102 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
104 Self {
105 id,
106 vtable,
107 kernels: Default::default(),
108 }
109 }
110
111 pub fn id(&self) -> &ArcRef<str> {
113 &self.id
114 }
115
116 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
118 self.kernels.write().push(kernel);
119 }
120
121 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
123 if self.is_elementwise() {
125 if !args
127 .inputs
128 .iter()
129 .filter_map(|input| input.array())
130 .map(|array| array.len())
131 .all_equal()
132 {
133 vortex_bail!(
134 "Compute function {} is elementwise but input arrays have different lengths",
135 self.id
136 );
137 }
138 }
139
140 let expected_dtype = self.vtable.return_dtype(args)?;
141 let expected_len = self.vtable.return_len(args)?;
142
143 let output = self.vtable.invoke(args, &self.kernels.read())?;
144
145 if output.dtype() != &expected_dtype {
146 vortex_bail!(
147 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
148 self.id,
149 output.dtype(),
150 &expected_dtype,
151 args.inputs
152 .iter()
153 .filter_map(|input| input.array())
154 .format_with(",", |array, f| f(&array.display_tree()))
155 );
156 }
157 if output.len() != expected_len {
158 vortex_bail!(
159 "Internal error: compute function {} returned a result of length {} but expected {}",
160 self.id,
161 output.len(),
162 expected_len
163 );
164 }
165
166 Ok(output)
167 }
168
169 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
171 self.vtable.return_dtype(args)
172 }
173
174 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
176 self.vtable.return_len(args)
177 }
178
179 pub fn is_elementwise(&self) -> bool {
181 self.vtable.is_elementwise()
183 }
184
185 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
187 self.kernels.read().to_vec()
188 }
189}
190
/// The definition of a [`ComputeFn`]: how it is invoked and what result it must produce.
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Invokes the function on `args`.
    ///
    /// `kernels` are the kernels currently registered on the owning [`ComputeFn`].
    /// [`ComputeFn::invoke`] checks the returned [`Output`] against [`Self::return_dtype`] and
    /// [`Self::return_len`].
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
        -> VortexResult<Output>;

    /// Computes the dtype the output must have for the given arguments.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the length the output must have for the given arguments.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise. If so, [`ComputeFn::invoke`] rejects calls whose
    /// array inputs have differing lengths.
    fn is_elementwise(&self) -> bool;
}
221
/// The arguments of a single compute-function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs; their number and kinds depend on the compute function.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options. Callees recover the concrete type through
    /// [`Options::as_any`], as the `TryFrom` conversions into [`UnaryArgs`] / [`BinaryArgs`] do.
    pub options: &'a dyn Options,
}
228
/// Typed view of an [`InvocationArgs`] that has exactly one array input and options of type `O`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The options, downcast to their concrete type.
    pub options: &'a O,
}
234
235impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
236 type Error = VortexError;
237
238 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
239 if value.inputs.len() != 1 {
240 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
241 }
242 let array = value.inputs[0]
243 .array()
244 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
245 let options =
246 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
247 vortex_err!("Expected options to be of type {}", type_name::<O>())
248 })?;
249 Ok(UnaryArgs { array, options })
250 }
251}
252
/// Typed view of an [`InvocationArgs`] that has exactly two array inputs and options of type `O`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first (left-hand) array input.
    pub lhs: &'a dyn Array,
    /// The second (right-hand) array input.
    pub rhs: &'a dyn Array,
    /// The options, downcast to their concrete type.
    pub options: &'a O,
}
259
260impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
261 type Error = VortexError;
262
263 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
264 if value.inputs.len() != 2 {
265 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
266 }
267 let lhs = value.inputs[0]
268 .array()
269 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
270 let rhs = value.inputs[1]
271 .array()
272 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
273 let options =
274 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
275 vortex_err!("Expected options to be of type {}", type_name::<O>())
276 })?;
277 Ok(BinaryArgs { lhs, rhs, options })
278 }
279}
280
/// A single positional argument to a compute function.
pub enum Input<'a> {
    /// A scalar value.
    Scalar(&'a Scalar),
    /// An array.
    Array(&'a dyn Array),
    /// A mask.
    Mask(&'a Mask),
    /// A mutable array builder.
    /// NOTE(review): `InvocationArgs::inputs` is a shared slice, so a callee holding only
    /// `&Input` cannot get the `&mut` back out — confirm how this variant is meant to be consumed.
    Builder(&'a mut dyn ArrayBuilder),
    /// A data type.
    DType(&'a DType),
}
289
290impl Debug for Input<'_> {
291 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292 let mut f = f.debug_struct("Input");
293 match self {
294 Input::Scalar(scalar) => f.field("Scalar", scalar),
295 Input::Array(array) => f.field("Array", array),
296 Input::Mask(mask) => f.field("Mask", mask),
297 Input::Builder(builder) => f.field("Builder", &builder.len()),
298 Input::DType(dtype) => f.field("DType", dtype),
299 };
300 f.finish()
301 }
302}
303
304impl<'a> From<&'a dyn Array> for Input<'a> {
305 fn from(value: &'a dyn Array) -> Self {
306 Input::Array(value)
307 }
308}
309
310impl<'a> From<&'a Scalar> for Input<'a> {
311 fn from(value: &'a Scalar) -> Self {
312 Input::Scalar(value)
313 }
314}
315
316impl<'a> From<&'a Mask> for Input<'a> {
317 fn from(value: &'a Mask) -> Self {
318 Input::Mask(value)
319 }
320}
321
322impl<'a> From<&'a DType> for Input<'a> {
323 fn from(value: &'a DType) -> Self {
324 Input::DType(value)
325 }
326}
327
328impl<'a> Input<'a> {
329 pub fn scalar(&self) -> Option<&'a Scalar> {
330 if let Input::Scalar(scalar) = self {
331 Some(*scalar)
332 } else {
333 None
334 }
335 }
336
337 pub fn array(&self) -> Option<&'a dyn Array> {
338 if let Input::Array(array) = self {
339 Some(*array)
340 } else {
341 None
342 }
343 }
344
345 pub fn mask(&self) -> Option<&'a Mask> {
346 if let Input::Mask(mask) = self {
347 Some(*mask)
348 } else {
349 None
350 }
351 }
352
353 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
354 if let Input::Builder(builder) = self {
355 Some(*builder)
356 } else {
357 None
358 }
359 }
360
361 pub fn dtype(&self) -> Option<&'a DType> {
362 if let Input::DType(dtype) = self {
363 Some(*dtype)
364 } else {
365 None
366 }
367 }
368}
369
/// The result of a compute function: either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
376
377#[allow(clippy::len_without_is_empty)]
378impl Output {
379 pub fn dtype(&self) -> &DType {
380 match self {
381 Output::Scalar(scalar) => scalar.dtype(),
382 Output::Array(array) => array.dtype(),
383 }
384 }
385
386 pub fn len(&self) -> usize {
387 match self {
388 Output::Scalar(_) => 1,
389 Output::Array(array) => array.len(),
390 }
391 }
392
393 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
394 match self {
395 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
396 Output::Scalar(scalar) => Ok(scalar),
397 }
398 }
399
400 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
401 match self {
402 Output::Array(array) => Ok(array),
403 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
404 }
405 }
406}
407
408impl From<ArrayRef> for Output {
409 fn from(value: ArrayRef) -> Self {
410 Output::Array(value)
411 }
412}
413
414impl From<Scalar> for Output {
415 fn from(value: Scalar) -> Self {
416 Output::Scalar(value)
417 }
418}
419
/// Options for a compute function. They are exposed as [`Any`] so that a function can downcast
/// `&dyn Options` back to its concrete options type.
pub trait Options: 'static {
    /// Returns this value as `&dyn Any` for downcasting; implementations are expected to return
    /// `self`.
    fn as_any(&self) -> &dyn Any;
}
424
425impl Options for () {
426 fn as_any(&self) -> &dyn Any {
427 self
428 }
429}
430
/// An implementation that can be attached to a [`ComputeFn`] via [`ComputeFn::register_kernel`].
/// Registered kernels are handed to the function's [`ComputeFnVTable::invoke`].
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempts to compute the function for `args`.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel does not handle these arguments"
    /// and leaves the vtable free to try something else — confirm against the vtable
    /// implementations.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
444
/// Submits the expression `$T` via `inventory::submit!` (re-exported through
/// `$crate::aliases::inventory`).
///
/// NOTE(review): presumably `$T` is a kernel-registration value gathered by a matching
/// `inventory::collect!` elsewhere in the crate — confirm the collected type there.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}