1use std::any::{Any, type_name};
13use std::fmt::{Debug, Formatter};
14
15use arcref::ArcRef;
16pub use between::*;
17pub use boolean::*;
18pub use cast::*;
19pub use compare::*;
20pub use fill_null::*;
21pub use filter::*;
22pub use invert::*;
23pub use is_constant::*;
24pub use is_sorted::*;
25use itertools::Itertools;
26pub use like::*;
27pub use list_contains::*;
28pub use mask::*;
29pub use min_max::*;
30pub use nan_count::*;
31pub use numeric::*;
32use parking_lot::RwLock;
33pub use sum::*;
34pub use take::*;
35use vortex_dtype::DType;
36use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37use vortex_mask::Mask;
38use vortex_scalar::Scalar;
39pub use zip::*;
40
41use crate::builders::ArrayBuilder;
42use crate::{Array, ArrayRef};
43
44#[cfg(feature = "arbitrary")]
45mod arbitrary;
46mod between;
47mod boolean;
48mod cast;
49mod compare;
50#[cfg(feature = "test-harness")]
51pub mod conformance;
52mod fill_null;
53mod filter;
54mod invert;
55mod is_constant;
56mod is_sorted;
57mod like;
58mod list_contains;
59mod mask;
60mod min_max;
61mod nan_count;
62mod numeric;
63mod sum;
64mod take;
65mod zip;
66
/// A compute function: an identifier, a vtable describing its behavior, and a list of kernels
/// that can be registered at any time through a shared reference.
pub struct ComputeFn {
    /// Identifier of the function; included in the error messages produced by `invoke`.
    id: ArcRef<str>,
    /// Implementation of the function: invocation, return dtype, return length and
    /// elementwise-ness.
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels registered via `register_kernel`, in registration order. Guarded by a lock so
    /// that registration only needs `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
74
75impl ComputeFn {
76 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
78 Self {
79 id,
80 vtable,
81 kernels: Default::default(),
82 }
83 }
84
85 pub fn id(&self) -> &ArcRef<str> {
87 &self.id
88 }
89
90 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
92 self.kernels.write().push(kernel);
93 }
94
95 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
97 if self.is_elementwise() {
99 if !args
101 .inputs
102 .iter()
103 .filter_map(|input| input.array())
104 .map(|array| array.len())
105 .all_equal()
106 {
107 vortex_bail!(
108 "Compute function {} is elementwise but input arrays have different lengths",
109 self.id
110 );
111 }
112 }
113
114 let expected_dtype = self.vtable.return_dtype(args)?;
115 let expected_len = self.vtable.return_len(args)?;
116
117 let output = self.vtable.invoke(args, &self.kernels.read())?;
118
119 if output.dtype() != &expected_dtype {
120 vortex_bail!(
121 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
122 self.id,
123 output.dtype(),
124 &expected_dtype,
125 args.inputs
126 .iter()
127 .filter_map(|input| input.array())
128 .format_with(",", |array, f| f(&array.display_tree()))
129 );
130 }
131 if output.len() != expected_len {
132 vortex_bail!(
133 "Internal error: compute function {} returned a result of length {} but expected {}",
134 self.id,
135 output.len(),
136 expected_len
137 );
138 }
139
140 Ok(output)
141 }
142
143 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
145 self.vtable.return_dtype(args)
146 }
147
148 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
150 self.vtable.return_len(args)
151 }
152
153 pub fn is_elementwise(&self) -> bool {
155 self.vtable.is_elementwise()
157 }
158
159 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
161 self.kernels.read().to_vec()
162 }
163}
164
/// The behavior behind a [`ComputeFn`].
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Computes the result for `args`. `kernels` is the list currently registered on the owning
    /// `ComputeFn`, in registration order.
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// Returns the dtype the output of `invoke` must have for `args`.
    ///
    /// `ComputeFn::invoke` returns an internal error if the actual output dtype differs.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// Returns the length the output of `invoke` must have for `args`.
    ///
    /// `ComputeFn::invoke` returns an internal error if the actual output length differs.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise.
    ///
    /// When `true`, `ComputeFn::invoke` rejects invocations whose array inputs have differing
    /// lengths.
    fn is_elementwise(&self) -> bool;
}
195
/// The arguments of a single compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs.
    pub inputs: &'a [Input<'a>],
    /// Type-erased options; recover the concrete type via `Options::as_any` + `downcast_ref`.
    pub options: &'a dyn Options,
}
202
/// A typed view of [`InvocationArgs`] with exactly one array input and options of type `O`.
///
/// Built with `UnaryArgs::try_from(&invocation_args)`.
pub struct UnaryArgs<'a, O: Options> {
    /// The single array input.
    pub array: &'a dyn Array,
    /// The options, downcast to `O`.
    pub options: &'a O,
}
208
209impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
210 type Error = VortexError;
211
212 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
213 if value.inputs.len() != 1 {
214 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
215 }
216 let array = value.inputs[0]
217 .array()
218 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
219 let options =
220 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
221 vortex_err!("Expected options to be of type {}", type_name::<O>())
222 })?;
223 Ok(UnaryArgs { array, options })
224 }
225}
226
/// A typed view of [`InvocationArgs`] with exactly two array inputs and options of type `O`.
///
/// Built with `BinaryArgs::try_from(&invocation_args)`.
pub struct BinaryArgs<'a, O: Options> {
    /// The first input (index 0).
    pub lhs: &'a dyn Array,
    /// The second input (index 1).
    pub rhs: &'a dyn Array,
    /// The options, downcast to `O`.
    pub options: &'a O,
}
233
234impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
235 type Error = VortexError;
236
237 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
238 if value.inputs.len() != 2 {
239 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
240 }
241 let lhs = value.inputs[0]
242 .array()
243 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
244 let rhs = value.inputs[1]
245 .array()
246 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
247 let options =
248 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
249 vortex_err!("Expected options to be of type {}", type_name::<O>())
250 })?;
251 Ok(BinaryArgs { lhs, rhs, options })
252 }
253}
254
/// A single positional input to a compute function invocation.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array. Only these inputs take part in the elementwise length check in
    /// `ComputeFn::invoke`.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder.
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
263
264impl Debug for Input<'_> {
265 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266 let mut f = f.debug_struct("Input");
267 match self {
268 Input::Scalar(scalar) => f.field("Scalar", scalar),
269 Input::Array(array) => f.field("Array", array),
270 Input::Mask(mask) => f.field("Mask", mask),
271 Input::Builder(builder) => f.field("Builder", &builder.len()),
272 Input::DType(dtype) => f.field("DType", dtype),
273 };
274 f.finish()
275 }
276}
277
278impl<'a> From<&'a dyn Array> for Input<'a> {
279 fn from(value: &'a dyn Array) -> Self {
280 Input::Array(value)
281 }
282}
283
284impl<'a> From<&'a Scalar> for Input<'a> {
285 fn from(value: &'a Scalar) -> Self {
286 Input::Scalar(value)
287 }
288}
289
290impl<'a> From<&'a Mask> for Input<'a> {
291 fn from(value: &'a Mask) -> Self {
292 Input::Mask(value)
293 }
294}
295
296impl<'a> From<&'a DType> for Input<'a> {
297 fn from(value: &'a DType) -> Self {
298 Input::DType(value)
299 }
300}
301
302impl<'a> Input<'a> {
303 pub fn scalar(&self) -> Option<&'a Scalar> {
304 if let Input::Scalar(scalar) = self {
305 Some(*scalar)
306 } else {
307 None
308 }
309 }
310
311 pub fn array(&self) -> Option<&'a dyn Array> {
312 if let Input::Array(array) = self {
313 Some(*array)
314 } else {
315 None
316 }
317 }
318
319 pub fn mask(&self) -> Option<&'a Mask> {
320 if let Input::Mask(mask) = self {
321 Some(*mask)
322 } else {
323 None
324 }
325 }
326
327 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
328 if let Input::Builder(builder) = self {
329 Some(*builder)
330 } else {
331 None
332 }
333 }
334
335 pub fn dtype(&self) -> Option<&'a DType> {
336 if let Input::DType(dtype) = self {
337 Some(*dtype)
338 } else {
339 None
340 }
341 }
342}
343
/// The result of a compute function invocation: either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A single scalar value; reports length 1.
    Scalar(Scalar),
    /// An array value.
    Array(ArrayRef),
}
350
351#[allow(clippy::len_without_is_empty)]
352impl Output {
353 pub fn dtype(&self) -> &DType {
354 match self {
355 Output::Scalar(scalar) => scalar.dtype(),
356 Output::Array(array) => array.dtype(),
357 }
358 }
359
360 pub fn len(&self) -> usize {
361 match self {
362 Output::Scalar(_) => 1,
363 Output::Array(array) => array.len(),
364 }
365 }
366
367 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
368 match self {
369 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
370 Output::Scalar(scalar) => Ok(scalar),
371 }
372 }
373
374 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
375 match self {
376 Output::Array(array) => Ok(array),
377 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
378 }
379 }
380}
381
382impl From<ArrayRef> for Output {
383 fn from(value: ArrayRef) -> Self {
384 Output::Array(value)
385 }
386}
387
388impl From<Scalar> for Output {
389 fn from(value: Scalar) -> Self {
390 Output::Scalar(value)
391 }
392}
393
/// Type-erased options for a compute function invocation.
///
/// Implementors expose themselves as `&dyn Any` so that the concrete options type can be
/// recovered with `downcast_ref`.
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// Lets the unit type `()` be used as an `Options` value.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
404
/// A kernel implementation that a [`ComputeFnVTable`] receives (via `ComputeFn::register_kernel`)
/// and may call during `invoke`.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempts to compute the result for `args`.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel does not handle these arguments"
    /// so the caller can try another kernel — confirm against the vtable implementations.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
418
/// Submits `$T` to the `inventory` registry re-exported at `$crate::aliases::inventory`.
///
/// Expands to `inventory::submit!($T);`, so it must be used where that macro is valid.
#[macro_export]
macro_rules! register_kernel {
    ($T:expr) => {
        $crate::aliases::inventory::submit!($T);
    };
}