1use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39
40use crate::builders::ArrayBuilder;
41use crate::{Array, ArrayRef};
42
43#[cfg(feature = "arbitrary")]
44mod arbitrary;
45mod between;
46mod boolean;
47mod cast;
48mod compare;
49#[cfg(feature = "test-harness")]
50pub mod conformance;
51mod fill_null;
52mod filter;
53mod invert;
54mod is_constant;
55mod is_sorted;
56mod like;
57mod list_contains;
58mod mask;
59mod min_max;
60mod nan_count;
61mod numeric;
62mod sum;
63mod take;
64
65pub struct ComputeFn {
68 id: ArcRef<str>,
69 vtable: ArcRef<dyn ComputeFnVTable>,
70 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
71}
72
73impl ComputeFn {
74 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
76 Self {
77 id,
78 vtable,
79 kernels: Default::default(),
80 }
81 }
82
83 pub fn id(&self) -> &ArcRef<str> {
85 &self.id
86 }
87
88 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
90 self.kernels.write().push(kernel);
91 }
92
93 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
95 if self.is_elementwise() {
97 if !args
99 .inputs
100 .iter()
101 .filter_map(|input| input.array())
102 .map(|array| array.len())
103 .all_equal()
104 {
105 vortex_bail!(
106 "Compute function {} is elementwise but input arrays have different lengths",
107 self.id
108 );
109 }
110 }
111
112 let expected_dtype = self.vtable.return_dtype(args)?;
113 let expected_len = self.vtable.return_len(args)?;
114
115 let output = self.vtable.invoke(args, &self.kernels.read())?;
116
117 if output.dtype() != &expected_dtype {
118 vortex_bail!(
119 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
120 self.id,
121 output.dtype(),
122 &expected_dtype,
123 args.inputs
124 .iter()
125 .filter_map(|input| input.array())
126 .format_with(",", |array, f| f(&array.display_tree()))
127 );
128 }
129 if output.len() != expected_len {
130 vortex_bail!(
131 "Internal error: compute function {} returned a result of length {} but expected {}",
132 self.id,
133 output.len(),
134 expected_len
135 );
136 }
137
138 Ok(output)
139 }
140
141 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
143 self.vtable.return_dtype(args)
144 }
145
146 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
148 self.vtable.return_len(args)
149 }
150
151 pub fn is_elementwise(&self) -> bool {
153 self.vtable.is_elementwise()
155 }
156
157 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
159 self.kernels.read().to_vec()
160 }
161}
162
163pub trait ComputeFnVTable: 'static + Send + Sync {
165 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
171 -> VortexResult<Output>;
172
173 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
177
178 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
183
184 fn is_elementwise(&self) -> bool;
192}
193
194#[derive(Clone)]
196pub struct InvocationArgs<'a> {
197 pub inputs: &'a [Input<'a>],
198 pub options: &'a dyn Options,
199}
200
201pub struct UnaryArgs<'a, O: Options> {
203 pub array: &'a dyn Array,
204 pub options: &'a O,
205}
206
207impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
208 type Error = VortexError;
209
210 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
211 if value.inputs.len() != 1 {
212 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
213 }
214 let array = value.inputs[0]
215 .array()
216 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
217 let options =
218 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
219 vortex_err!("Expected options to be of type {}", type_name::<O>())
220 })?;
221 Ok(UnaryArgs { array, options })
222 }
223}
224
225pub struct BinaryArgs<'a, O: Options> {
227 pub lhs: &'a dyn Array,
228 pub rhs: &'a dyn Array,
229 pub options: &'a O,
230}
231
232impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
233 type Error = VortexError;
234
235 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
236 if value.inputs.len() != 2 {
237 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
238 }
239 let lhs = value.inputs[0]
240 .array()
241 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
242 let rhs = value.inputs[1]
243 .array()
244 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
245 let options =
246 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
247 vortex_err!("Expected options to be of type {}", type_name::<O>())
248 })?;
249 Ok(BinaryArgs { lhs, rhs, options })
250 }
251}
252
253pub enum Input<'a> {
255 Scalar(&'a Scalar),
256 Array(&'a dyn Array),
257 Mask(&'a Mask),
258 Builder(&'a mut dyn ArrayBuilder),
259 DType(&'a DType),
260}
261
262impl Debug for Input<'_> {
263 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264 let mut f = f.debug_struct("Input");
265 match self {
266 Input::Scalar(scalar) => f.field("Scalar", scalar),
267 Input::Array(array) => f.field("Array", array),
268 Input::Mask(mask) => f.field("Mask", mask),
269 Input::Builder(builder) => f.field("Builder", &builder.len()),
270 Input::DType(dtype) => f.field("DType", dtype),
271 };
272 f.finish()
273 }
274}
275
276impl<'a> From<&'a dyn Array> for Input<'a> {
277 fn from(value: &'a dyn Array) -> Self {
278 Input::Array(value)
279 }
280}
281
282impl<'a> From<&'a Scalar> for Input<'a> {
283 fn from(value: &'a Scalar) -> Self {
284 Input::Scalar(value)
285 }
286}
287
288impl<'a> From<&'a Mask> for Input<'a> {
289 fn from(value: &'a Mask) -> Self {
290 Input::Mask(value)
291 }
292}
293
294impl<'a> From<&'a DType> for Input<'a> {
295 fn from(value: &'a DType) -> Self {
296 Input::DType(value)
297 }
298}
299
300impl<'a> Input<'a> {
301 pub fn scalar(&self) -> Option<&'a Scalar> {
302 match self {
303 Input::Scalar(scalar) => Some(*scalar),
304 _ => None,
305 }
306 }
307
308 pub fn array(&self) -> Option<&'a dyn Array> {
309 match self {
310 Input::Array(array) => Some(*array),
311 _ => None,
312 }
313 }
314
315 pub fn mask(&self) -> Option<&'a Mask> {
316 match self {
317 Input::Mask(mask) => Some(*mask),
318 _ => None,
319 }
320 }
321
322 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
323 match self {
324 Input::Builder(builder) => Some(*builder),
325 _ => None,
326 }
327 }
328
329 pub fn dtype(&self) -> Option<&'a DType> {
330 match self {
331 Input::DType(dtype) => Some(*dtype),
332 _ => None,
333 }
334 }
335}
336
337#[derive(Debug)]
339pub enum Output {
340 Scalar(Scalar),
341 Array(ArrayRef),
342}
343
344#[allow(clippy::len_without_is_empty)]
345impl Output {
346 pub fn dtype(&self) -> &DType {
347 match self {
348 Output::Scalar(scalar) => scalar.dtype(),
349 Output::Array(array) => array.dtype(),
350 }
351 }
352
353 pub fn len(&self) -> usize {
354 match self {
355 Output::Scalar(_) => 1,
356 Output::Array(array) => array.len(),
357 }
358 }
359
360 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
361 match self {
362 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
363 Output::Scalar(scalar) => Ok(scalar),
364 }
365 }
366
367 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
368 match self {
369 Output::Array(array) => Ok(array),
370 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
371 }
372 }
373}
374
375impl From<ArrayRef> for Output {
376 fn from(value: ArrayRef) -> Self {
377 Output::Array(value)
378 }
379}
380
381impl From<Scalar> for Output {
382 fn from(value: Scalar) -> Self {
383 Output::Scalar(value)
384 }
385}
386
/// Type-erased, function-specific options passed in [`InvocationArgs::options`].
pub trait Options: 'static {
    /// Exposes `self` as [`Any`] so callers can downcast to the concrete options type.
    fn as_any(&self) -> &dyn Any;
}

/// The unit type serves as "no options".
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
397
398pub trait Kernel: 'static + Send + Sync + Debug {
408 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
410}
411
/// Submits `$kernel` to the `inventory` registry re-exported at `$crate::aliases::inventory`.
///
/// NOTE(review): the type the submitted expression must have is fixed by the matching
/// `inventory::collect!` declaration elsewhere in the crate — confirm there.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}