1use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46mod between;
47mod boolean;
48mod cast;
49mod compare;
50#[cfg(feature = "test-harness")]
51pub mod conformance;
52mod fill_null;
53mod filter;
54mod invert;
55mod is_constant;
56mod is_sorted;
57mod like;
58mod list_contains;
59mod mask;
60mod min_max;
61mod nan_count;
62mod numeric;
63mod sum;
64mod take;
65mod zip;
66
/// A named compute function.
///
/// Its behavior (return dtype, return length, elementwise-ness and the invocation itself) is
/// defined by a [`ComputeFnVTable`]. Additional [`Kernel`]s can be registered at runtime and
/// are handed to the vtable on every invocation.
pub struct ComputeFn {
    /// Identifier of the function; included in the error messages produced by `invoke`.
    id: ArcRef<str>,
    /// The vtable that computes return dtype/length and performs the invocation.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added via `register_kernel`; passed to `vtable.invoke` on each call.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
74
75impl ComputeFn {
76 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
78 Self {
79 id,
80 vtable,
81 kernels: Default::default(),
82 }
83 }
84
85 pub fn id(&self) -> &ArcRef<str> {
87 &self.id
88 }
89
90 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
92 self.kernels.write().push(kernel);
93 }
94
95 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
97 if self.is_elementwise() {
99 if !args
101 .inputs
102 .iter()
103 .filter_map(|input| input.array())
104 .map(|array| array.len())
105 .all_equal()
106 {
107 vortex_bail!(
108 "Compute function {} is elementwise but input arrays have different lengths",
109 self.id
110 );
111 }
112 }
113
114 let expected_dtype = self.vtable.return_dtype(args)?;
115 let expected_len = self.vtable.return_len(args)?;
116
117 let output = self.vtable.invoke(args, &self.kernels.read())?;
118
119 if output.dtype() != &expected_dtype {
120 vortex_bail!(
121 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
122 self.id,
123 output.dtype(),
124 &expected_dtype,
125 args.inputs
126 .iter()
127 .filter_map(|input| input.array())
128 .format_with(",", |array, f| f(&array.display_tree()))
129 );
130 }
131 if output.len() != expected_len {
132 vortex_bail!(
133 "Internal error: compute function {} returned a result of length {} but expected {}",
134 self.id,
135 output.len(),
136 expected_len
137 );
138 }
139
140 Ok(output)
141 }
142
143 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
145 self.vtable.return_dtype(args)
146 }
147
148 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
150 self.vtable.return_len(args)
151 }
152
153 pub fn is_elementwise(&self) -> bool {
155 self.vtable.is_elementwise()
157 }
158
159 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
161 self.kernels.read().to_vec()
162 }
163}
164
/// The behavior behind a [`ComputeFn`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Performs the computation for `args`.
    ///
    /// `kernels` is the list of kernels currently registered on the owning [`ComputeFn`].
    /// How they are consulted is up to the implementation. [`ComputeFn::invoke`] checks the
    /// returned output against [`Self::return_dtype`] and [`Self::return_len`].
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Computes the dtype that the output of [`Self::invoke`] must have for `args`.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Computes the length that the output of [`Self::invoke`] must have for `args`.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise.
    ///
    /// When `true`, [`ComputeFn::invoke`] rejects invocations whose array inputs do not all
    /// have the same length.
    fn is_elementwise(&self) -> bool;
}
195
/// The arguments of a single compute-function invocation: positional inputs plus a
/// type-erased options value.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs, in the order the function expects them.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options; recover the concrete type with
    /// `options.as_any().downcast_ref::<T>()`.
    pub options: &'a dyn Options,
}
202
/// Typed view of [`InvocationArgs`] for functions that take exactly one array input.
///
/// Build one with `UnaryArgs::try_from(&args)`, which validates the input count, the input
/// kind and the options type.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The invocation options, downcast to `O`.
    pub options: &'a O,
}
208
209impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
210 type Error = VortexError;
211
212 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
213 if value.inputs.len() != 1 {
214 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
215 }
216 let array = value.inputs[0]
217 .array()
218 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
219 let options =
220 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
221 vortex_err!("Expected options to be of type {}", type_name::<O>())
222 })?;
223 Ok(UnaryArgs { array, options })
224 }
225}
226
/// Typed view of [`InvocationArgs`] for functions that take exactly two array inputs.
///
/// Build one with `BinaryArgs::try_from(&args)`, which validates the input count, the input
/// kinds and the options type.
pub struct BinaryArgs<'a, O: Options> {
    /// The first array input (input 0).
    pub lhs: &'a dyn Array,
    /// The second array input (input 1).
    pub rhs: &'a dyn Array,
    /// The invocation options, downcast to `O`.
    pub options: &'a O,
}
233
234impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
235 type Error = VortexError;
236
237 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
238 if value.inputs.len() != 2 {
239 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
240 }
241 let lhs = value.inputs[0]
242 .array()
243 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
244 let rhs = value.inputs[1]
245 .array()
246 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
247 let options =
248 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
249 vortex_err!("Expected options to be of type {}", type_name::<O>())
250 })?;
251 Ok(BinaryArgs { lhs, rhs, options })
252 }
253}
254
/// A single positional input to a compute function.
///
/// All variants borrow their payload for `'a`. `Builder` holds a *mutable* borrow, so
/// `Input` cannot derive `Clone`/`Copy`.
pub enum Input<'a> {
    /// A scalar value.
    Scalar(&'a Scalar),
    /// An array.
    Array(&'a dyn Array),
    /// A mask.
    Mask(&'a Mask),
    /// A mutable array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A data type.
    DType(&'a DType),
}
263
264impl Debug for Input<'_> {
265 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266 let mut f = f.debug_struct("Input");
267 match self {
268 Input::Scalar(scalar) => f.field("Scalar", scalar),
269 Input::Array(array) => f.field("Array", array),
270 Input::Mask(mask) => f.field("Mask", mask),
271 Input::Builder(builder) => f.field("Builder", &builder.len()),
272 Input::DType(dtype) => f.field("DType", dtype),
273 };
274 f.finish()
275 }
276}
277
278impl<'a> From<&'a dyn Array> for Input<'a> {
279 fn from(value: &'a dyn Array) -> Self {
280 Input::Array(value)
281 }
282}
283
284impl<'a> From<&'a Scalar> for Input<'a> {
285 fn from(value: &'a Scalar) -> Self {
286 Input::Scalar(value)
287 }
288}
289
290impl<'a> From<&'a Mask> for Input<'a> {
291 fn from(value: &'a Mask) -> Self {
292 Input::Mask(value)
293 }
294}
295
296impl<'a> From<&'a DType> for Input<'a> {
297 fn from(value: &'a DType) -> Self {
298 Input::DType(value)
299 }
300}
301
302impl<'a> Input<'a> {
303 pub fn scalar(&self) -> Option<&'a Scalar> {
304 match self {
305 Input::Scalar(scalar) => Some(*scalar),
306 _ => None,
307 }
308 }
309
310 pub fn array(&self) -> Option<&'a dyn Array> {
311 match self {
312 Input::Array(array) => Some(*array),
313 _ => None,
314 }
315 }
316
317 pub fn mask(&self) -> Option<&'a Mask> {
318 match self {
319 Input::Mask(mask) => Some(*mask),
320 _ => None,
321 }
322 }
323
324 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
325 match self {
326 Input::Builder(builder) => Some(*builder),
327 _ => None,
328 }
329 }
330
331 pub fn dtype(&self) -> Option<&'a DType> {
332 match self {
333 Input::DType(dtype) => Some(*dtype),
334 _ => None,
335 }
336 }
337}
338
/// The result of invoking a compute function: either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A single scalar result; counts as length 1.
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
345
346#[allow(clippy::len_without_is_empty)]
347impl Output {
348 pub fn dtype(&self) -> &DType {
349 match self {
350 Output::Scalar(scalar) => scalar.dtype(),
351 Output::Array(array) => array.dtype(),
352 }
353 }
354
355 pub fn len(&self) -> usize {
356 match self {
357 Output::Scalar(_) => 1,
358 Output::Array(array) => array.len(),
359 }
360 }
361
362 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
363 match self {
364 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
365 Output::Scalar(scalar) => Ok(scalar),
366 }
367 }
368
369 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
370 match self {
371 Output::Array(array) => Ok(array),
372 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
373 }
374 }
375}
376
377impl From<ArrayRef> for Output {
378 fn from(value: ArrayRef) -> Self {
379 Output::Array(value)
380 }
381}
382
383impl From<Scalar> for Output {
384 fn from(value: Scalar) -> Self {
385 Output::Scalar(value)
386 }
387}
388
/// Function-specific options passed to a compute function via [`InvocationArgs`].
///
/// Options are type-erased. Consumers recover the concrete type by downcasting the value
/// returned from [`Options::as_any`] (see `UnaryArgs`/`BinaryArgs::try_from`).
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
393
394impl Options for () {
395 fn as_any(&self) -> &dyn Any {
396 self
397 }
398}
399
/// A specialized implementation of a compute function.
///
/// Kernels are added with [`ComputeFn::register_kernel`] and passed to
/// [`ComputeFnVTable::invoke`] on every invocation.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempts to compute the result for `args`.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel does not handle these
    /// arguments" and lets the vtable try something else. Confirm against the vtable
    /// implementations.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
413
/// Submits the expression `$T` to the `inventory` crate, re-exported as
/// `$crate::aliases::inventory`, so it can be collected at startup.
///
/// NOTE(review): the concrete type `$T` must have is set by the matching
/// `inventory::collect!` declaration, which is not visible here. Confirm there.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}
421}