vortex_array/compute/
mod.rs1use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use boolean::*;
19#[expect(deprecated)]
20pub use cast::cast;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24#[expect(deprecated)]
25pub use invert::invert;
26pub use is_constant::*;
27pub use is_sorted::*;
28use itertools::Itertools;
29pub use list_contains::*;
30pub use mask::*;
31pub use min_max::*;
32pub use nan_count::*;
33pub use numeric::*;
34use parking_lot::RwLock;
35pub use sum::*;
36use vortex_dtype::DType;
37use vortex_error::VortexError;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_err;
41use vortex_mask::Mask;
42pub use zip::*;
43
44use crate::Array;
45use crate::ArrayRef;
46use crate::builders::ArrayBuilder;
47pub use crate::expr::BetweenExecuteAdaptor;
48pub use crate::expr::BetweenKernel;
49pub use crate::expr::BetweenReduce;
50pub use crate::expr::BetweenReduceAdaptor;
51pub use crate::expr::CastExecuteAdaptor;
52pub use crate::expr::CastKernel;
53pub use crate::expr::CastReduce;
54pub use crate::expr::CastReduceAdaptor;
55pub use crate::expr::FillNullExecuteAdaptor;
56pub use crate::expr::FillNullKernel;
57pub use crate::expr::FillNullReduce;
58pub use crate::expr::FillNullReduceAdaptor;
59pub use crate::expr::MaskExecuteAdaptor;
60pub use crate::expr::MaskKernel;
61pub use crate::expr::MaskReduce;
62pub use crate::expr::MaskReduceAdaptor;
63pub use crate::expr::NotExecuteAdaptor;
64pub use crate::expr::NotKernel;
65pub use crate::expr::NotReduce;
66pub use crate::expr::NotReduceAdaptor;
67use crate::scalar::Scalar;
68
69#[cfg(feature = "arbitrary")]
70mod arbitrary;
71mod boolean;
72mod cast;
73mod compare;
74#[cfg(feature = "_test-harness")]
75pub mod conformance;
76mod fill_null;
77mod filter;
78mod invert;
79mod is_constant;
80mod is_sorted;
81mod list_contains;
82mod mask;
83mod min_max;
84mod nan_count;
85mod numeric;
86mod sum;
87mod zip;
88
89pub struct ComputeFn {
92 id: ArcRef<str>,
93 vtable: ArcRef<dyn ComputeFnVTable>,
94 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
95}
96
97pub fn warm_up_vtables() {
101 #[allow(unused_qualifications)]
102 is_constant::warm_up_vtable();
103 is_sorted::warm_up_vtable();
104 list_contains::warm_up_vtable();
105 min_max::warm_up_vtable();
106 nan_count::warm_up_vtable();
107 sum::warm_up_vtable();
108}
109
110impl ComputeFn {
111 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
113 Self {
114 id,
115 vtable,
116 kernels: Default::default(),
117 }
118 }
119
120 pub fn id(&self) -> &ArcRef<str> {
122 &self.id
123 }
124
125 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
127 self.kernels.write().push(kernel);
128 }
129
130 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
132 if self.is_elementwise() {
134 if !args
136 .inputs
137 .iter()
138 .filter_map(|input| input.array())
139 .map(|array| array.len())
140 .all_equal()
141 {
142 vortex_bail!(
143 "Compute function {} is elementwise but input arrays have different lengths",
144 self.id
145 );
146 }
147 }
148
149 let expected_dtype = self.vtable.return_dtype(args)?;
150 let expected_len = self.vtable.return_len(args)?;
151
152 let output = self.vtable.invoke(args, &self.kernels.read())?;
153
154 if output.dtype() != &expected_dtype {
155 vortex_bail!(
156 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
157 self.id,
158 output.dtype(),
159 &expected_dtype,
160 args.inputs
161 .iter()
162 .filter_map(|input| input.array())
163 .format_with(",", |array, f| f(&array.encoding_id()))
164 );
165 }
166 if output.len() != expected_len {
167 vortex_bail!(
168 "Internal error: compute function {} returned a result of length {} but expected {}",
169 self.id,
170 output.len(),
171 expected_len
172 );
173 }
174
175 Ok(output)
176 }
177
178 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
180 self.vtable.return_dtype(args)
181 }
182
183 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
185 self.vtable.return_len(args)
186 }
187
188 pub fn is_elementwise(&self) -> bool {
190 self.vtable.is_elementwise()
192 }
193
194 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
196 self.kernels.read().to_vec()
197 }
198}
199
200pub trait ComputeFnVTable: 'static + Send + Sync {
202 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
208 -> VortexResult<Output>;
209
210 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
214
215 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
220
221 fn is_elementwise(&self) -> bool;
229}
230
231#[derive(Clone)]
233pub struct InvocationArgs<'a> {
234 pub inputs: &'a [Input<'a>],
235 pub options: &'a dyn Options,
236}
237
238pub struct UnaryArgs<'a, O: Options> {
240 pub array: &'a dyn Array,
241 pub options: &'a O,
242}
243
244impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
245 type Error = VortexError;
246
247 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
248 if value.inputs.len() != 1 {
249 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
250 }
251 let array = value.inputs[0]
252 .array()
253 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
254 let options =
255 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
256 vortex_err!("Expected options to be of type {}", type_name::<O>())
257 })?;
258 Ok(UnaryArgs { array, options })
259 }
260}
261
262pub struct BinaryArgs<'a, O: Options> {
264 pub lhs: &'a dyn Array,
265 pub rhs: &'a dyn Array,
266 pub options: &'a O,
267}
268
269impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
270 type Error = VortexError;
271
272 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
273 if value.inputs.len() != 2 {
274 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
275 }
276 let lhs = value.inputs[0]
277 .array()
278 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
279 let rhs = value.inputs[1]
280 .array()
281 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
282 let options =
283 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
284 vortex_err!("Expected options to be of type {}", type_name::<O>())
285 })?;
286 Ok(BinaryArgs { lhs, rhs, options })
287 }
288}
289
290pub enum Input<'a> {
292 Scalar(&'a Scalar),
293 Array(&'a dyn Array),
294 Mask(&'a Mask),
295 Builder(&'a mut dyn ArrayBuilder),
296 DType(&'a DType),
297}
298
299impl Debug for Input<'_> {
300 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
301 let mut f = f.debug_struct("Input");
302 match self {
303 Input::Scalar(scalar) => f.field("Scalar", scalar),
304 Input::Array(array) => f.field("Array", array),
305 Input::Mask(mask) => f.field("Mask", mask),
306 Input::Builder(builder) => f.field("Builder", &builder.len()),
307 Input::DType(dtype) => f.field("DType", dtype),
308 };
309 f.finish()
310 }
311}
312
313impl<'a> From<&'a dyn Array> for Input<'a> {
314 fn from(value: &'a dyn Array) -> Self {
315 Input::Array(value)
316 }
317}
318
319impl<'a> From<&'a Scalar> for Input<'a> {
320 fn from(value: &'a Scalar) -> Self {
321 Input::Scalar(value)
322 }
323}
324
325impl<'a> From<&'a Mask> for Input<'a> {
326 fn from(value: &'a Mask) -> Self {
327 Input::Mask(value)
328 }
329}
330
331impl<'a> From<&'a DType> for Input<'a> {
332 fn from(value: &'a DType) -> Self {
333 Input::DType(value)
334 }
335}
336
337impl<'a> Input<'a> {
338 pub fn scalar(&self) -> Option<&'a Scalar> {
339 if let Input::Scalar(scalar) = self {
340 Some(*scalar)
341 } else {
342 None
343 }
344 }
345
346 pub fn array(&self) -> Option<&'a dyn Array> {
347 if let Input::Array(array) = self {
348 Some(*array)
349 } else {
350 None
351 }
352 }
353
354 pub fn mask(&self) -> Option<&'a Mask> {
355 if let Input::Mask(mask) = self {
356 Some(*mask)
357 } else {
358 None
359 }
360 }
361
362 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
363 if let Input::Builder(builder) = self {
364 Some(*builder)
365 } else {
366 None
367 }
368 }
369
370 pub fn dtype(&self) -> Option<&'a DType> {
371 if let Input::DType(dtype) = self {
372 Some(*dtype)
373 } else {
374 None
375 }
376 }
377}
378
379#[derive(Debug)]
381pub enum Output {
382 Scalar(Scalar),
383 Array(ArrayRef),
384}
385
386#[expect(
387 clippy::len_without_is_empty,
388 reason = "Output is always non-empty (scalar has len 1)"
389)]
390impl Output {
391 pub fn dtype(&self) -> &DType {
392 match self {
393 Output::Scalar(scalar) => scalar.dtype(),
394 Output::Array(array) => array.dtype(),
395 }
396 }
397
398 pub fn len(&self) -> usize {
399 match self {
400 Output::Scalar(_) => 1,
401 Output::Array(array) => array.len(),
402 }
403 }
404
405 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
406 match self {
407 Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
408 Output::Scalar(scalar) => Ok(scalar),
409 }
410 }
411
412 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
413 match self {
414 Output::Array(array) => Ok(array),
415 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
416 }
417 }
418}
419
420impl From<ArrayRef> for Output {
421 fn from(value: ArrayRef) -> Self {
422 Output::Array(value)
423 }
424}
425
426impl From<Scalar> for Output {
427 fn from(value: Scalar) -> Self {
428 Output::Scalar(value)
429 }
430}
431
/// Type-erased options for a compute function.
///
/// Implementors expose themselves as [`Any`] so callers can downcast back to the concrete
/// options type.
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}
436
437impl Options for () {
438 fn as_any(&self) -> &dyn Any {
439 self
440 }
441}
442
443pub trait Kernel: 'static + Send + Sync + Debug {
453 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
455}
456
/// Submits the given expression to the crate's re-exported `inventory` registry via
/// `$crate::aliases::inventory::submit!`.
///
/// NOTE(review): which inventory collection receives it depends on the expression's type —
/// confirm against the `inventory::collect!` declarations elsewhere in the crate.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}
464}