vortex_array/compute/
mod.rs1use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use boolean::*;
19#[expect(deprecated)]
20pub use cast::cast;
21pub use compare::*;
22pub use fill_null::*;
23pub use filter::*;
24#[expect(deprecated)]
25pub use invert::invert;
26pub use is_constant::*;
27pub use is_sorted::*;
28use itertools::Itertools;
29#[expect(deprecated)]
30pub use list_contains::list_contains;
31pub use mask::*;
32pub use min_max::*;
33pub use nan_count::*;
34pub use numeric::*;
35use parking_lot::RwLock;
36pub use sum::*;
37use vortex_error::VortexError;
38use vortex_error::VortexResult;
39use vortex_error::vortex_bail;
40use vortex_error::vortex_err;
41use vortex_mask::Mask;
42pub use zip::*;
43
44use crate::Array;
45use crate::ArrayRef;
46use crate::builders::ArrayBuilder;
47use crate::dtype::DType;
48use crate::scalar::Scalar;
49
50#[cfg(feature = "arbitrary")]
51mod arbitrary;
52mod boolean;
53mod cast;
54mod compare;
55#[cfg(feature = "_test-harness")]
56pub mod conformance;
57mod fill_null;
58mod filter;
59mod invert;
60mod is_constant;
61mod is_sorted;
62mod list_contains;
63mod mask;
64mod min_max;
65mod nan_count;
66mod numeric;
67mod sum;
68mod zip;
69
70pub struct ComputeFn {
73 id: ArcRef<str>,
74 vtable: ArcRef<dyn ComputeFnVTable>,
75 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
76}
77
/// Runs the `warm_up_vtable` hook of each compute module that defines one, in this order:
/// `is_constant`, `is_sorted`, `min_max`, `nan_count`, `sum`.
///
/// NOTE(review): what "warming up" does is defined in those submodules and is not visible
/// here — presumably it eagerly initializes lazily-built vtables; confirm there.
pub fn warm_up_vtables() {
    // NOTE(review): this lint allowance is attached to the `is_constant` statement only;
    // the reason it is needed there (and not for the other calls) is not recorded here.
    #[allow(unused_qualifications)]
    is_constant::warm_up_vtable();
    is_sorted::warm_up_vtable();
    min_max::warm_up_vtable();
    nan_count::warm_up_vtable();
    sum::warm_up_vtable();
}
89
90impl ComputeFn {
91 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
93 Self {
94 id,
95 vtable,
96 kernels: Default::default(),
97 }
98 }
99
100 pub fn id(&self) -> &ArcRef<str> {
102 &self.id
103 }
104
105 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
107 self.kernels.write().push(kernel);
108 }
109
110 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
112 if self.is_elementwise() {
114 if !args
116 .inputs
117 .iter()
118 .filter_map(|input| input.array())
119 .map(|array| array.len())
120 .all_equal()
121 {
122 vortex_bail!(
123 "Compute function {} is elementwise but input arrays have different lengths",
124 self.id
125 );
126 }
127 }
128
129 let expected_dtype = self.vtable.return_dtype(args)?;
130 let expected_len = self.vtable.return_len(args)?;
131
132 let output = self.vtable.invoke(args, &self.kernels.read())?;
133
134 if output.dtype() != &expected_dtype {
135 vortex_bail!(
136 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
137 self.id,
138 output.dtype(),
139 &expected_dtype,
140 args.inputs
141 .iter()
142 .filter_map(|input| input.array())
143 .format_with(",", |array, f| f(&array.encoding_id()))
144 );
145 }
146 if output.len() != expected_len {
147 vortex_bail!(
148 "Internal error: compute function {} returned a result of length {} but expected {}",
149 self.id,
150 output.len(),
151 expected_len
152 );
153 }
154
155 Ok(output)
156 }
157
158 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
160 self.vtable.return_dtype(args)
161 }
162
163 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
165 self.vtable.return_len(args)
166 }
167
168 pub fn is_elementwise(&self) -> bool {
170 self.vtable.is_elementwise()
172 }
173
174 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
176 self.kernels.read().to_vec()
177 }
178}
179
180pub trait ComputeFnVTable: 'static + Send + Sync {
182 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
188 -> VortexResult<Output>;
189
190 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
194
195 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
200
201 fn is_elementwise(&self) -> bool;
209}
210
211#[derive(Clone)]
213pub struct InvocationArgs<'a> {
214 pub inputs: &'a [Input<'a>],
215 pub options: &'a dyn Options,
216}
217
218pub struct UnaryArgs<'a, O: Options> {
220 pub array: &'a dyn Array,
221 pub options: &'a O,
222}
223
224impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
225 type Error = VortexError;
226
227 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
228 if value.inputs.len() != 1 {
229 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
230 }
231 let array = value.inputs[0]
232 .array()
233 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
234 let options =
235 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
236 vortex_err!("Expected options to be of type {}", type_name::<O>())
237 })?;
238 Ok(UnaryArgs { array, options })
239 }
240}
241
242pub struct BinaryArgs<'a, O: Options> {
244 pub lhs: &'a dyn Array,
245 pub rhs: &'a dyn Array,
246 pub options: &'a O,
247}
248
249impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
250 type Error = VortexError;
251
252 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
253 if value.inputs.len() != 2 {
254 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
255 }
256 let lhs = value.inputs[0]
257 .array()
258 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
259 let rhs = value.inputs[1]
260 .array()
261 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
262 let options =
263 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
264 vortex_err!("Expected options to be of type {}", type_name::<O>())
265 })?;
266 Ok(BinaryArgs { lhs, rhs, options })
267 }
268}
269
270pub enum Input<'a> {
272 Scalar(&'a Scalar),
273 Array(&'a dyn Array),
274 Mask(&'a Mask),
275 Builder(&'a mut dyn ArrayBuilder),
276 DType(&'a DType),
277}
278
279impl Debug for Input<'_> {
280 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281 let mut f = f.debug_struct("Input");
282 match self {
283 Input::Scalar(scalar) => f.field("Scalar", scalar),
284 Input::Array(array) => f.field("Array", array),
285 Input::Mask(mask) => f.field("Mask", mask),
286 Input::Builder(builder) => f.field("Builder", &builder.len()),
287 Input::DType(dtype) => f.field("DType", dtype),
288 };
289 f.finish()
290 }
291}
292
293impl<'a> From<&'a dyn Array> for Input<'a> {
294 fn from(value: &'a dyn Array) -> Self {
295 Input::Array(value)
296 }
297}
298
299impl<'a> From<&'a Scalar> for Input<'a> {
300 fn from(value: &'a Scalar) -> Self {
301 Input::Scalar(value)
302 }
303}
304
305impl<'a> From<&'a Mask> for Input<'a> {
306 fn from(value: &'a Mask) -> Self {
307 Input::Mask(value)
308 }
309}
310
311impl<'a> From<&'a DType> for Input<'a> {
312 fn from(value: &'a DType) -> Self {
313 Input::DType(value)
314 }
315}
316
317impl<'a> Input<'a> {
318 pub fn scalar(&self) -> Option<&'a Scalar> {
319 if let Input::Scalar(scalar) = self {
320 Some(*scalar)
321 } else {
322 None
323 }
324 }
325
326 pub fn array(&self) -> Option<&'a dyn Array> {
327 if let Input::Array(array) = self {
328 Some(*array)
329 } else {
330 None
331 }
332 }
333
334 pub fn mask(&self) -> Option<&'a Mask> {
335 if let Input::Mask(mask) = self {
336 Some(*mask)
337 } else {
338 None
339 }
340 }
341
342 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
343 if let Input::Builder(builder) = self {
344 Some(*builder)
345 } else {
346 None
347 }
348 }
349
350 pub fn dtype(&self) -> Option<&'a DType> {
351 if let Input::DType(dtype) = self {
352 Some(*dtype)
353 } else {
354 None
355 }
356 }
357}
358
359#[derive(Debug)]
361pub enum Output {
362 Scalar(Scalar),
363 Array(ArrayRef),
364}
365
366#[expect(
367 clippy::len_without_is_empty,
368 reason = "Output is always non-empty (scalar has len 1)"
369)]
370impl Output {
371 pub fn dtype(&self) -> &DType {
372 match self {
373 Output::Scalar(scalar) => scalar.dtype(),
374 Output::Array(array) => array.dtype(),
375 }
376 }
377
378 pub fn len(&self) -> usize {
379 match self {
380 Output::Scalar(_) => 1,
381 Output::Array(array) => array.len(),
382 }
383 }
384
385 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
386 match self {
387 Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
388 Output::Scalar(scalar) => Ok(scalar),
389 }
390 }
391
392 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
393 match self {
394 Output::Array(array) => Ok(array),
395 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
396 }
397 }
398}
399
400impl From<ArrayRef> for Output {
401 fn from(value: ArrayRef) -> Self {
402 Output::Array(value)
403 }
404}
405
406impl From<Scalar> for Output {
407 fn from(value: Scalar) -> Self {
408 Output::Scalar(value)
409 }
410}
411
/// Type-erased, function-specific options for a compute invocation.
///
/// Implementors expose themselves as [`Any`] so that callers can downcast back to the
/// concrete options type.
pub trait Options: 'static {
    /// Returns `self` as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any;
}
416
417impl Options for () {
418 fn as_any(&self) -> &dyn Any {
419 self
420 }
421}
422
423pub trait Kernel: 'static + Send + Sync + Debug {
433 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
435}
436
/// Submits the given expression to the `inventory` registry re-exported at
/// `$crate::aliases::inventory`.
///
/// NOTE(review): the type the submitted value must have is fixed by an `inventory::collect!`
/// declaration elsewhere in the crate — presumably a kernel-registration type; confirm there.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}
444}