1use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11use std::sync::RwLock;
12
13use arcref::ArcRef;
14pub use between::*;
15pub use boolean::*;
16pub use cast::*;
17pub use compare::*;
18pub use fill_null::*;
19pub use filter::*;
20pub use invert::*;
21pub use is_constant::*;
22pub use is_sorted::*;
23use itertools::Itertools;
24pub use like::*;
25pub use list::*;
26pub use mask::*;
27pub use min_max::*;
28pub use nan_count::*;
29pub use numeric::*;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
/// A named compute function: a [`ComputeFnVTable`] that implements type/length inference and
/// dispatch, plus a registry of [`Kernel`]s that the vtable is handed on every invocation.
pub struct ComputeFn {
    /// Identifier of the function; returned by `id()` and reported in `invoke` error messages.
    id: ArcRef<str>,
    /// The function's behavior (dtype/length inference and dispatch).
    vtable: ArcRef<dyn ComputeFnVTable>,
    /// Kernels added through `register_kernel`. Kept behind a lock so that registration works
    /// through a shared `&self`.
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
}
69
70impl ComputeFn {
71 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73 Self {
74 id,
75 vtable,
76 kernels: Default::default(),
77 }
78 }
79
80 pub fn id(&self) -> &ArcRef<str> {
82 &self.id
83 }
84
85 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87 self.kernels
88 .write()
89 .vortex_expect("poisoned lock")
90 .push(kernel);
91 }
92
93 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
95 if self.is_elementwise() {
97 if !args
99 .inputs
100 .iter()
101 .filter_map(|input| input.array())
102 .map(|array| array.len())
103 .all_equal()
104 {
105 vortex_bail!(
106 "Compute function {} is elementwise but input arrays have different lengths",
107 self.id
108 );
109 }
110 }
111
112 let expected_dtype = self.vtable.return_dtype(args)?;
113 let expected_len = self.vtable.return_len(args)?;
114
115 let output = self
116 .vtable
117 .invoke(args, &self.kernels.read().vortex_expect("poisoned lock"))?;
118
119 if output.dtype() != &expected_dtype {
120 vortex_bail!(
121 "Internal error: compute function {} returned a result of type {} but expected {}",
122 self.id,
123 output.dtype(),
124 &expected_dtype
125 );
126 }
127 if output.len() != expected_len {
128 vortex_bail!(
129 "Internal error: compute function {} returned a result of length {} but expected {}",
130 self.id,
131 output.len(),
132 expected_len
133 );
134 }
135
136 Ok(output)
137 }
138
139 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
141 self.vtable.return_dtype(args)
142 }
143
144 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
146 self.vtable.return_len(args)
147 }
148
149 pub fn is_elementwise(&self) -> bool {
151 self.vtable.is_elementwise()
153 }
154
155 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
157 self.kernels.read().vortex_expect("poisoned lock").to_vec()
158 }
159}
160
/// The behavior behind a [`ComputeFn`]: result dtype/length inference plus the actual
/// computation, which may dispatch to the registered kernels.
pub trait ComputeFnVTable: 'static + Send + Sync {
    /// Compute the result for `args`, given the kernels currently registered with the owning
    /// [`ComputeFn`].
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
    -> VortexResult<Output>;

    /// The dtype that invoking with `args` will produce. `ComputeFn::invoke` rejects any
    /// output whose dtype differs from this.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;

    /// The length that invoking with `args` will produce. `ComputeFn::invoke` rejects any
    /// output whose length differs from this.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;

    /// Whether the function is elementwise. For elementwise functions, `ComputeFn::invoke`
    /// requires every array input to have the same length.
    fn is_elementwise(&self) -> bool;
}
191
/// The arguments of a single compute function invocation.
#[derive(Clone)]
pub struct InvocationArgs<'a> {
    /// Positional inputs, in the order the function expects them.
    pub inputs: &'a [Input<'a>],
    /// Function-specific options; recover the concrete type via [`Options::as_any`].
    pub options: &'a dyn Options,
}
198
199pub struct UnaryArgs<'a, O: Options> {
201 pub array: &'a dyn Array,
202 pub options: &'a O,
203}
204
205impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
206 type Error = VortexError;
207
208 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
209 if value.inputs.len() != 1 {
210 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
211 }
212 let array = value.inputs[0]
213 .array()
214 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
215 let options =
216 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
217 vortex_err!("Expected options to be of type {}", type_name::<O>())
218 })?;
219 Ok(UnaryArgs { array, options })
220 }
221}
222
223pub struct BinaryArgs<'a, O: Options> {
225 pub lhs: &'a dyn Array,
226 pub rhs: &'a dyn Array,
227 pub options: &'a O,
228}
229
230impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
231 type Error = VortexError;
232
233 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
234 if value.inputs.len() != 2 {
235 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
236 }
237 let lhs = value.inputs[0]
238 .array()
239 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
240 let rhs = value.inputs[1]
241 .array()
242 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
243 let options =
244 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
245 vortex_err!("Expected options to be of type {}", type_name::<O>())
246 })?;
247 Ok(BinaryArgs { lhs, rhs, options })
248 }
249}
250
/// A single positional argument to a compute function invocation.
pub enum Input<'a> {
    /// A borrowed scalar value.
    Scalar(&'a Scalar),
    /// A borrowed array.
    Array(&'a dyn Array),
    /// A borrowed mask.
    Mask(&'a Mask),
    /// A mutably borrowed array builder (presumably for functions that write their result into
    /// it — confirm against callers).
    Builder(&'a mut dyn ArrayBuilder),
    /// A borrowed data type.
    DType(&'a DType),
}
259
260impl Debug for Input<'_> {
261 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
262 let mut f = f.debug_struct("Input");
263 match self {
264 Input::Scalar(scalar) => f.field("Scalar", scalar),
265 Input::Array(array) => f.field("Array", array),
266 Input::Mask(mask) => f.field("Mask", mask),
267 Input::Builder(builder) => f.field("Builder", &builder.len()),
268 Input::DType(dtype) => f.field("DType", dtype),
269 };
270 f.finish()
271 }
272}
273
274impl<'a> From<&'a dyn Array> for Input<'a> {
275 fn from(value: &'a dyn Array) -> Self {
276 Input::Array(value)
277 }
278}
279
280impl<'a> From<&'a Scalar> for Input<'a> {
281 fn from(value: &'a Scalar) -> Self {
282 Input::Scalar(value)
283 }
284}
285
286impl<'a> From<&'a Mask> for Input<'a> {
287 fn from(value: &'a Mask) -> Self {
288 Input::Mask(value)
289 }
290}
291
292impl<'a> From<&'a DType> for Input<'a> {
293 fn from(value: &'a DType) -> Self {
294 Input::DType(value)
295 }
296}
297
298impl<'a> Input<'a> {
299 pub fn scalar(&self) -> Option<&'a Scalar> {
300 match self {
301 Input::Scalar(scalar) => Some(*scalar),
302 _ => None,
303 }
304 }
305
306 pub fn array(&self) -> Option<&'a dyn Array> {
307 match self {
308 Input::Array(array) => Some(*array),
309 _ => None,
310 }
311 }
312
313 pub fn mask(&self) -> Option<&'a Mask> {
314 match self {
315 Input::Mask(mask) => Some(*mask),
316 _ => None,
317 }
318 }
319
320 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
321 match self {
322 Input::Builder(builder) => Some(*builder),
323 _ => None,
324 }
325 }
326
327 pub fn dtype(&self) -> Option<&'a DType> {
328 match self {
329 Input::DType(dtype) => Some(*dtype),
330 _ => None,
331 }
332 }
333}
334
/// The result of a compute function: either a single scalar or an array.
#[derive(Debug)]
pub enum Output {
    /// A scalar result (counts as length 1).
    Scalar(Scalar),
    /// An array result.
    Array(ArrayRef),
}
341
342#[allow(clippy::len_without_is_empty)]
343impl Output {
344 pub fn dtype(&self) -> &DType {
345 match self {
346 Output::Scalar(scalar) => scalar.dtype(),
347 Output::Array(array) => array.dtype(),
348 }
349 }
350
351 pub fn len(&self) -> usize {
352 match self {
353 Output::Scalar(_) => 1,
354 Output::Array(array) => array.len(),
355 }
356 }
357
358 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
359 match self {
360 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
361 Output::Scalar(scalar) => Ok(scalar),
362 }
363 }
364
365 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
366 match self {
367 Output::Array(array) => Ok(array),
368 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
369 }
370 }
371}
372
373impl From<ArrayRef> for Output {
374 fn from(value: ArrayRef) -> Self {
375 Output::Array(value)
376 }
377}
378
379impl From<Scalar> for Output {
380 fn from(value: Scalar) -> Self {
381 Output::Scalar(value)
382 }
383}
384
/// Options for a compute function, passed type-erased in [`InvocationArgs`].
///
/// Implementors expose themselves as [`Any`] so that a function can recover its concrete
/// options type with `downcast_ref`.
pub trait Options: 'static {
    /// View `self` as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// Unit options, for functions with nothing to configure.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
395
/// A specialized implementation of a compute function. Kernels are added with
/// `ComputeFn::register_kernel` and handed to `ComputeFnVTable::invoke`.
pub trait Kernel: 'static + Send + Sync + Debug {
    /// Attempt to compute the result for `args`.
    ///
    /// NOTE(review): presumably `Ok(None)` means this kernel does not handle these arguments
    /// and the vtable should try another kernel or a fallback — confirm against vtable impls.
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
}
409
/// Register a kernel by forwarding the given expression to `inventory::submit!`, using the
/// `inventory` re-export at `$crate::aliases::inventory`.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}
417}