vortex_array/compute/
mod.rs1use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use between::*;
19pub use boolean::*;
20pub use cast::*;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24pub use invert::*;
25pub use is_constant::*;
26pub use is_sorted::*;
27use itertools::Itertools;
28pub use like::*;
29pub use list_contains::*;
30pub use mask::*;
31pub use min_max::*;
32pub use nan_count::*;
33pub use numeric::*;
34use parking_lot::RwLock;
35pub use sum::*;
36pub use take::*;
37use vortex_dtype::DType;
38use vortex_error::VortexError;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_err;
42use vortex_mask::Mask;
43use vortex_scalar::Scalar;
44pub use zip::*;
45
46use crate::Array;
47use crate::ArrayRef;
48use crate::builders::ArrayBuilder;
49
50#[cfg(feature = "arbitrary")]
51mod arbitrary;
52mod between;
53mod boolean;
54mod cast;
55mod compare;
56#[cfg(feature = "test-harness")]
57pub mod conformance;
58mod fill_null;
59mod filter;
60mod invert;
61mod is_constant;
62mod is_sorted;
63mod like;
64mod list_contains;
65mod mask;
66mod min_max;
67mod nan_count;
68mod numeric;
69mod sum;
70mod take;
71mod zip;
72
73pub struct ComputeFn {
76 id: ArcRef<str>,
77 vtable: ArcRef<dyn ComputeFnVTable>,
78 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
79}
80
81pub fn warm_up_vtables() {
85 crate::arrow::warm_up_vtable();
86 #[allow(unused_qualifications)]
87 between::warm_up_vtable();
88 boolean::warm_up_vtable();
89 cast::warm_up_vtable();
90 compare::warm_up_vtable();
91 fill_null::warm_up_vtable();
92 filter::warm_up_vtable();
93 invert::warm_up_vtable();
94 is_constant::warm_up_vtable();
95 is_sorted::warm_up_vtable();
96 like::warm_up_vtable();
97 list_contains::warm_up_vtable();
98 mask::warm_up_vtable();
99 min_max::warm_up_vtable();
100 nan_count::warm_up_vtable();
101 numeric::warm_up_vtable();
102 sum::warm_up_vtable();
103 take::warm_up_vtable();
104 zip::warm_up_vtable();
105}
106
107impl ComputeFn {
108 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
110 Self {
111 id,
112 vtable,
113 kernels: Default::default(),
114 }
115 }
116
117 pub fn id(&self) -> &ArcRef<str> {
119 &self.id
120 }
121
122 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
124 self.kernels.write().push(kernel);
125 }
126
127 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
129 if self.is_elementwise() {
131 if !args
133 .inputs
134 .iter()
135 .filter_map(|input| input.array())
136 .map(|array| array.len())
137 .all_equal()
138 {
139 vortex_bail!(
140 "Compute function {} is elementwise but input arrays have different lengths",
141 self.id
142 );
143 }
144 }
145
146 let expected_dtype = self.vtable.return_dtype(args)?;
147 let expected_len = self.vtable.return_len(args)?;
148
149 let output = self.vtable.invoke(args, &self.kernels.read())?;
150
151 if output.dtype() != &expected_dtype {
152 vortex_bail!(
153 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
154 self.id,
155 output.dtype(),
156 &expected_dtype,
157 args.inputs
158 .iter()
159 .filter_map(|input| input.array())
160 .format_with(",", |array, f| f(&array.display_tree()))
161 );
162 }
163 if output.len() != expected_len {
164 vortex_bail!(
165 "Internal error: compute function {} returned a result of length {} but expected {}",
166 self.id,
167 output.len(),
168 expected_len
169 );
170 }
171
172 Ok(output)
173 }
174
175 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
177 self.vtable.return_dtype(args)
178 }
179
180 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
182 self.vtable.return_len(args)
183 }
184
185 pub fn is_elementwise(&self) -> bool {
187 self.vtable.is_elementwise()
189 }
190
191 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
193 self.kernels.read().to_vec()
194 }
195}
196
197pub trait ComputeFnVTable: 'static + Send + Sync {
199 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
205 -> VortexResult<Output>;
206
207 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
211
212 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
217
218 fn is_elementwise(&self) -> bool;
226}
227
228#[derive(Clone)]
230pub struct InvocationArgs<'a> {
231 pub inputs: &'a [Input<'a>],
232 pub options: &'a dyn Options,
233}
234
235pub struct UnaryArgs<'a, O: Options> {
237 pub array: &'a dyn Array,
238 pub options: &'a O,
239}
240
241impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
242 type Error = VortexError;
243
244 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
245 if value.inputs.len() != 1 {
246 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
247 }
248 let array = value.inputs[0]
249 .array()
250 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
251 let options =
252 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
253 vortex_err!("Expected options to be of type {}", type_name::<O>())
254 })?;
255 Ok(UnaryArgs { array, options })
256 }
257}
258
259pub struct BinaryArgs<'a, O: Options> {
261 pub lhs: &'a dyn Array,
262 pub rhs: &'a dyn Array,
263 pub options: &'a O,
264}
265
266impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
267 type Error = VortexError;
268
269 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
270 if value.inputs.len() != 2 {
271 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
272 }
273 let lhs = value.inputs[0]
274 .array()
275 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
276 let rhs = value.inputs[1]
277 .array()
278 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
279 let options =
280 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
281 vortex_err!("Expected options to be of type {}", type_name::<O>())
282 })?;
283 Ok(BinaryArgs { lhs, rhs, options })
284 }
285}
286
287pub enum Input<'a> {
289 Scalar(&'a Scalar),
290 Array(&'a dyn Array),
291 Mask(&'a Mask),
292 Builder(&'a mut dyn ArrayBuilder),
293 DType(&'a DType),
294}
295
296impl Debug for Input<'_> {
297 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298 let mut f = f.debug_struct("Input");
299 match self {
300 Input::Scalar(scalar) => f.field("Scalar", scalar),
301 Input::Array(array) => f.field("Array", array),
302 Input::Mask(mask) => f.field("Mask", mask),
303 Input::Builder(builder) => f.field("Builder", &builder.len()),
304 Input::DType(dtype) => f.field("DType", dtype),
305 };
306 f.finish()
307 }
308}
309
310impl<'a> From<&'a dyn Array> for Input<'a> {
311 fn from(value: &'a dyn Array) -> Self {
312 Input::Array(value)
313 }
314}
315
316impl<'a> From<&'a Scalar> for Input<'a> {
317 fn from(value: &'a Scalar) -> Self {
318 Input::Scalar(value)
319 }
320}
321
322impl<'a> From<&'a Mask> for Input<'a> {
323 fn from(value: &'a Mask) -> Self {
324 Input::Mask(value)
325 }
326}
327
328impl<'a> From<&'a DType> for Input<'a> {
329 fn from(value: &'a DType) -> Self {
330 Input::DType(value)
331 }
332}
333
334impl<'a> Input<'a> {
335 pub fn scalar(&self) -> Option<&'a Scalar> {
336 if let Input::Scalar(scalar) = self {
337 Some(*scalar)
338 } else {
339 None
340 }
341 }
342
343 pub fn array(&self) -> Option<&'a dyn Array> {
344 if let Input::Array(array) = self {
345 Some(*array)
346 } else {
347 None
348 }
349 }
350
351 pub fn mask(&self) -> Option<&'a Mask> {
352 if let Input::Mask(mask) = self {
353 Some(*mask)
354 } else {
355 None
356 }
357 }
358
359 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
360 if let Input::Builder(builder) = self {
361 Some(*builder)
362 } else {
363 None
364 }
365 }
366
367 pub fn dtype(&self) -> Option<&'a DType> {
368 if let Input::DType(dtype) = self {
369 Some(*dtype)
370 } else {
371 None
372 }
373 }
374}
375
376#[derive(Debug)]
378pub enum Output {
379 Scalar(Scalar),
380 Array(ArrayRef),
381}
382
383#[expect(
384 clippy::len_without_is_empty,
385 reason = "Output is always non-empty (scalar has len 1)"
386)]
387impl Output {
388 pub fn dtype(&self) -> &DType {
389 match self {
390 Output::Scalar(scalar) => scalar.dtype(),
391 Output::Array(array) => array.dtype(),
392 }
393 }
394
395 pub fn len(&self) -> usize {
396 match self {
397 Output::Scalar(_) => 1,
398 Output::Array(array) => array.len(),
399 }
400 }
401
402 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
403 match self {
404 Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
405 Output::Scalar(scalar) => Ok(scalar),
406 }
407 }
408
409 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
410 match self {
411 Output::Array(array) => Ok(array),
412 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
413 }
414 }
415}
416
417impl From<ArrayRef> for Output {
418 fn from(value: ArrayRef) -> Self {
419 Output::Array(value)
420 }
421}
422
423impl From<Scalar> for Output {
424 fn from(value: Scalar) -> Self {
425 Output::Scalar(value)
426 }
427}
428
/// Function-specific options carried by [`InvocationArgs`].
///
/// `as_any` exists so that consumers can downcast to the concrete options type, as the
/// `TryFrom<&InvocationArgs>` impls for the typed argument views do.
pub trait Options: 'static {
    /// Returns these options as `&dyn Any` so they can be downcast.
    fn as_any(&self) -> &dyn Any;
}
433
434impl Options for () {
435 fn as_any(&self) -> &dyn Any {
436 self
437 }
438}
439
440pub trait Kernel: 'static + Send + Sync + Debug {
450 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
452}
453
/// Submits a kernel registration to `inventory`, via the re-export at
/// `$crate::aliases::inventory`.
///
/// Expands to `$crate::aliases::inventory::submit!($kernel);`, so the argument must be
/// something `inventory::submit!` accepts.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}